diff --git a/.gitignore b/.gitignore index 19647a0..ff9a13d 100644 --- a/.gitignore +++ b/.gitignore @@ -26,4 +26,7 @@ configs/*.yaml !configs/ltxv_2b_full.yaml !configs/ltxv_2b_lora.yaml !configs/ltxv_2b_lora_low_vram.yaml -outputs +outputs/ +datasets/ +training_data/ +*.db diff --git a/scripts/app_gradio_v2.py b/scripts/app_gradio_v2.py new file mode 100644 index 0000000..40690ac --- /dev/null +++ b/scripts/app_gradio_v2.py @@ -0,0 +1,2221 @@ +"""Gradio interface for LTX Video Trainer.""" + +import datetime +import json +import logging +import os +import shutil +from dataclasses import dataclass +from datetime import timezone +from pathlib import Path + +import gradio as gr +import torch +import yaml +from huggingface_hub import login + +from ltxv_trainer.captioning import ( + DEFAULT_VLM_CAPTION_INSTRUCTION, + CaptionerType, + create_captioner, +) +from ltxv_trainer.hf_hub_utils import convert_video_to_gif +from ltxv_trainer.model_loader import ( + LtxvModelVersion, +) +from ltxv_trainer.trainer import LtxvTrainer, LtxvTrainerConfig +from scripts.dataset_manager import DatasetManager +from scripts.jobs.database import JobDatabase, JobStatus +from scripts.preprocess_dataset import preprocess_dataset +from scripts.process_videos import parse_resolution_buckets + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512" + +torch.cuda.empty_cache() +BASE_DIR = Path(__file__).parent +PROJECT_ROOT = BASE_DIR.parent +OUTPUTS_DIR = PROJECT_ROOT / "outputs" +TRAINING_DATA_DIR = PROJECT_ROOT / "training_data" +DATASETS_DIR = PROJECT_ROOT / "datasets" +VALIDATION_SAMPLES_DIR = OUTPUTS_DIR / "validation_samples" + +OUTPUTS_DIR.mkdir(exist_ok=True) +TRAINING_DATA_DIR.mkdir(exist_ok=True) +VALIDATION_SAMPLES_DIR.mkdir(exist_ok=True) +DATASETS_DIR.mkdir(exist_ok=True) + + +@dataclass +class TrainingConfigParams: + """Parameters for generating training configuration.""" + + model_source: str + learning_rate: float + steps: int + lora_rank: int + batch_size: int + validation_prompt: str + video_dims: tuple[int, int, int] + validation_interval: int = 100 + push_to_hub: bool = False + hub_model_id: str | None = None + + +@dataclass +class TrainingState: + """State for tracking training progress.""" + + status: str | None = None + progress: str | None = None + validation: str | None = None + download: str | None = None + error: str | None = None + hf_repo: str | None = None + checkpoint_path: str | None = None + + def reset(self) -> None: + """Reset state to initial values.""" + self.status = "running" + self.progress = None + self.validation = None + self.download = None + self.error = None + self.hf_repo = None + self.checkpoint_path = None + + def update(self, **kwargs: str | int | None) -> None: + """Update state with provided values.""" + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + + +@dataclass +class TrainingParams: + dataset_name: str + validation_prompt: str + learning_rate: float + steps: int + lora_rank: int + batch_size: int + model_source: str + width: int + height: int + num_frames: int + push_to_hub: bool + hf_model_id: str + hf_token: str | None = None + id_token: str | None = None + validation_interval: int = 100 + + +def _handle_validation_sample(step: int, video_path: Path) -> str | None: + """Handle validation sample conversion and storage. 
+ + Args: + step: Current training step + video_path: Path to the validation video + + Returns: + Path to the GIF file if successful, None otherwise + """ + gif_path = VALIDATION_SAMPLES_DIR / f"sample_step_{step}.gif" + try: + convert_video_to_gif(video_path, gif_path) + logger.info(f"New validation sample converted to GIF at step {step}: {gif_path}") + return str(gif_path) + except Exception as e: + logger.error(f"Failed to convert validation video to GIF: {e}") + return None + + +def generate_training_config(params: TrainingConfigParams, training_data_dir: str) -> dict: + """Generate training configuration from parameters. + + Args: + params: Training configuration parameters + training_data_dir: Directory containing training data + + Returns: + Dictionary containing the complete training configuration + """ + template_path = Path(__file__).parent.parent / "configs" / "ltxv_13b_lora_template.yaml" + with open(template_path) as f: + config = yaml.safe_load(f) + + config["model"]["model_source"] = params.model_source + config["lora"]["rank"] = params.lora_rank + config["lora"]["alpha"] = params.lora_rank + config["optimization"]["learning_rate"] = params.learning_rate + config["optimization"]["steps"] = params.steps + config["optimization"]["batch_size"] = params.batch_size + config["data"]["preprocessed_data_root"] = str(training_data_dir) + config["output_dir"] = str(OUTPUTS_DIR / f"lora_r{params.lora_rank}") + + config["hub"] = { + "push_to_hub": params.push_to_hub, + "hub_model_id": params.hub_model_id if params.push_to_hub else None, + } + + width, height, num_frames = params.video_dims + config["validation"] = { + "prompts": [params.validation_prompt], + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + "video_dims": [width, height, num_frames], + "seed": 42, + "inference_steps": 30, + "interval": params.validation_interval, + "videos_per_prompt": 1, + "guidance_scale": 3.5, + } + + if "validation" in config and "images" not in config["validation"]: + config["validation"]["images"] = None + + return config + + +class GradioUI: + """Class to manage Gradio UI components and state.""" + + def __init__(self) -> None: + self.training_state = TrainingState() + self.dataset_manager = DatasetManager(datasets_root=DATASETS_DIR) + self.current_dataset = None + + db_path = PROJECT_ROOT / "jobs.db" + self.job_db = JobDatabase(db_path) + self.current_job_id = None + + self.validation_prompt = None + self.status_output = None + self.progress_output = None + self.download_btn = None + self.hf_repo_link = None + + def reset_interface(self) -> dict: + """Reset the interface and clean up all training data. 
+ + Returns: + Dictionary of Gradio component updates + """ + self.training_state.reset() + self.current_job_id = None + + if TRAINING_DATA_DIR.exists(): + shutil.rmtree(TRAINING_DATA_DIR) + TRAINING_DATA_DIR.mkdir(exist_ok=True) + + if OUTPUTS_DIR.exists(): + shutil.rmtree(OUTPUTS_DIR) + OUTPUTS_DIR.mkdir(exist_ok=True) + VALIDATION_SAMPLES_DIR.mkdir(exist_ok=True) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + return { + self.validation_prompt: gr.update( + value="a professional portrait video of a person with blurry bokeh background", + info="Include the LoRA ID token (e.g., <lora>) in this prompt if desired.", + ), + self.status_output: gr.update(value=""), + self.progress_output: gr.update(value=""), + self.download_btn: gr.update(visible=False), + self.hf_repo_link: gr.update(visible=False, value=""), + } + + def get_model_path(self) -> str | None: + """Get the path to the trained model file.""" + if self.training_state.download and Path(self.training_state.download).exists(): + return self.training_state.download + return None + + def update_progress(self) -> tuple[str, str, gr.update, str, gr.update]: + """Update the UI with current training progress.""" + # Default return values + job_display = "" + progress_msg = "" + file_update = gr.update(visible=False) + hf_html = "" + hf_update = gr.update(visible=False) + + # Handle current job from queue + if self.current_job_id: + job = self.job_db.get_job(self.current_job_id) + if job: + job_display = f"**Current Job:** #{self.current_job_id} | Dataset: `{job['dataset_name']}`" + status = job["status"] + progress = job.get("progress", "") + + status_messages = { + JobStatus.PENDING: "Waiting in queue...", + JobStatus.RUNNING: progress if progress else "Training in progress...", + JobStatus.COMPLETED: "Training completed!", + JobStatus.FAILED: "Training failed", + JobStatus.CANCELLED: "Job cancelled", + } + progress_msg = status_messages.get(status, "") + + # Handle completed job artifacts + if status == JobStatus.COMPLETED: + hf_repo = job.get("hf_repo_url", "") + checkpoint_path = job.get("checkpoint_path", "") + + if hf_repo: + hf_html = f'View model on HuggingFace Hub' + hf_update = gr.update(visible=True) + elif checkpoint_path and Path(checkpoint_path).exists(): + file_update = gr.update( + value=checkpoint_path, visible=True, label=f"Download {Path(checkpoint_path).name}" + ) + + return (job_display, progress_msg, file_update, hf_html, hf_update) + + # Handle legacy direct training mode + if self.training_state.status is not None: + job_display = "**Direct Training Mode** (Legacy)" + progress_msg = self.training_state.progress + status = self.training_state.status + + if status == "complete": + if self.training_state.hf_repo: + hf_html = ( + f'View model on HuggingFace Hub' + ) + hf_update = gr.update(visible=True) + elif self.training_state.download and Path(self.training_state.download).exists(): + file_update = gr.update( + value=self.training_state.download, + visible=True, + label=f"Download {Path(self.training_state.download).name}", + ) + + return (job_display, progress_msg, file_update, hf_html, hf_update) + + # No active training + return (job_display, progress_msg, file_update, hf_html, hf_update) + + def _save_checkpoint(self, saved_path: Path, trainer_config: LtxvTrainerConfig) -> tuple[Path, str | None]: + """Save and copy the checkpoint to a permanent location. 
+ + Args: + saved_path: Path where the checkpoint was initially saved + trainer_config: Training configuration + + Returns: + Tuple of (permanent checkpoint path, HF repo URL if applicable) + """ + permanent_checkpoint_dir = OUTPUTS_DIR / "checkpoints" + permanent_checkpoint_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + checkpoint_filename = f"comfy_lora_checkpoint_{timestamp}.safetensors" + permanent_checkpoint_path = permanent_checkpoint_dir / checkpoint_filename + + try: + shutil.copy2(saved_path, permanent_checkpoint_path) + logger.info(f"Checkpoint copied to permanent location: {permanent_checkpoint_path}") + except Exception as e: + logger.error(f"Failed to copy checkpoint: {e}") + permanent_checkpoint_path = saved_path + + hf_repo = ( + f"https://huggingface.co/{trainer_config.hub.hub_model_id}" if trainer_config.hub.hub_model_id else None + ) + + return permanent_checkpoint_path, hf_repo + + def _preprocess_dataset( + self, + dataset_file: Path, + model_source: str, + width: int, + height: int, + num_frames: int, + id_token: str | None = None, + ) -> tuple[bool, str | None]: + """Preprocess the dataset by computing video latents and text embeddings. + + Args: + dataset_file: Path to the dataset.json file + model_source: Model source identifier + width: Video width + height: Video height + num_frames: Number of frames + id_token: Optional token to prepend to captions (for LoRA training) + + Returns: + Tuple of (success, error_message) + """ + try: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + precomputed_dir = TRAINING_DATA_DIR / ".precomputed" + if precomputed_dir.exists(): + shutil.rmtree(precomputed_dir) + + resolution_buckets = f"{width}x{height}x{num_frames}" + parsed_buckets = parse_resolution_buckets(resolution_buckets) + + preprocess_dataset( + dataset_file=str(dataset_file), + caption_column="caption", + video_column="media_path", + resolution_buckets=parsed_buckets, + batch_size=1, + output_dir=None, + id_token=id_token, + vae_tiling=False, + decode_videos=True, + model_source=model_source, + device="cuda" if torch.cuda.is_available() else "cpu", + load_text_encoder_in_8bit=False, + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return True, None + + except Exception as e: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return False, f"Error preprocessing dataset: {e!s}" + + def _should_preprocess_data( + self, + width: int, + height: int, + num_frames: int, + videos: list[str], + ) -> bool: + """Check if data needs to be preprocessed based on resolution changes. 
+ + Args: + width: Video width + height: Video height + num_frames: Number of frames + videos: List of video file paths + + Returns: + True if preprocessing is needed, False otherwise + """ + resolution_file = TRAINING_DATA_DIR / ".resolution_config" + current_resolution = f"{width}x{height}x{num_frames}" + needs_to_copy = False + for video in videos: + if Path(video).exists(): + needs_to_copy = True + if needs_to_copy: + logger.info("Videos provided, will copy them to training directory.") + return True, needs_to_copy + + if not resolution_file.exists() or not (TRAINING_DATA_DIR / "captions.json").exists(): + return True, needs_to_copy + + try: + with open(resolution_file) as f: + previous_resolution = f.read().strip() + return previous_resolution != current_resolution, needs_to_copy + except Exception: + return True, needs_to_copy + + def _save_resolution_config( + self, + width: int, + height: int, + num_frames: int, + ) -> None: + """Save current resolution configuration. + + Args: + width: Video width + height: Video height + num_frames: Number of frames + """ + resolution_file = TRAINING_DATA_DIR / ".resolution_config" + current_resolution = f"{width}x{height}x{num_frames}" + + with open(resolution_file, "w") as f: + f.write(current_resolution) + + def _sync_captions_from_ui( + self, params: TrainingParams, training_captions_file: Path + ) -> tuple[dict[str, str] | None, str | None]: + """Sync captions from the UI to captions.json. Returns (captions_data, error_message).""" + if params.captions_json: + try: + dataset = json.loads(params.captions_json) + # Convert list of dicts to captions_data dict + captions_data = {item["media_path"]: item["caption"] for item in dataset} + # Save to captions.json (overwrite every time) + with open(training_captions_file, "w") as f: + json.dump(captions_data, f, indent=2) + return captions_data, None + except Exception as e: + return None, f"Invalid captions JSON: {e!s}" + else: + return None, "No captions found in the UI. Please process videos first." + + # ruff: noqa: PLR0912 + def start_training( + self, + params: TrainingParams, + ) -> tuple[str, gr.update]: + """Queue a training job.""" + # Validate dataset exists + if not params.dataset_name: + return "Please select a dataset from the Datasets tab", gr.update(interactive=True) + + managed_dataset_dir = self.dataset_manager.datasets_root / params.dataset_name + managed_dataset_json = managed_dataset_dir / "dataset.json" + + if not managed_dataset_json.exists(): + return f"Dataset '{params.dataset_name}' not found or has no videos", gr.update(interactive=True) + + # Check if dataset is preprocessed + precomputed_dir = managed_dataset_dir / ".precomputed" + if not precomputed_dir.exists(): + return ( + f"Dataset '{params.dataset_name}' is not preprocessed. 
Please preprocess it in the Datasets tab first.", + gr.update(interactive=True), + ) + + # Create job parameters + job_params = { + "model_source": params.model_source, + "learning_rate": params.learning_rate, + "steps": params.steps, + "lora_rank": params.lora_rank, + "batch_size": params.batch_size, + "width": params.width, + "height": params.height, + "num_frames": params.num_frames, + "id_token": params.id_token or "", + "validation_prompt": params.validation_prompt, + "validation_interval": params.validation_interval, + "push_to_hub": params.push_to_hub, + "hf_model_id": params.hf_model_id if params.push_to_hub else "", + "hf_token": params.hf_token if params.push_to_hub else "", + } + + # Create job in database + try: + job_id = self.job_db.create_job(params.dataset_name, job_params) + self.current_job_id = job_id + + return ( + f"Job #{job_id} created and queued for training on dataset '{params.dataset_name}'.\n" + f"The worker will start training automatically. Check the Queue tab for status.", + gr.update(interactive=True), + ) + except Exception as e: + return f"Failed to create training job: {e}", gr.update(interactive=True) + + def start_training_direct( + self, + params: TrainingParams, + ) -> tuple[str, gr.update]: + """Start the training process directly (legacy method).""" + if params.hf_token: + login(token=params.hf_token) + + try: + # Clear any existing CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + # Set training status + self.training_state.reset() # This sets status to "running" + + # Prepare data directory + data_dir = TRAINING_DATA_DIR + data_dir.mkdir(exist_ok=True) + + # Validate dataset exists + if not params.dataset_name: + return "Please select a dataset from the Datasets tab", gr.update(interactive=True) + + # Use managed dataset + managed_dataset_dir = self.dataset_manager.datasets_root / params.dataset_name + managed_dataset_json = managed_dataset_dir / "dataset.json" + + if not managed_dataset_json.exists(): + return f"Dataset '{params.dataset_name}' not found or has no videos", gr.update(interactive=True) + + # Copy dataset.json to training directory + training_captions_file = data_dir / "captions.json" + shutil.copy2(managed_dataset_json, training_captions_file) + + # Load the dataset to get video paths + with open(training_captions_file) as f: + dataset = json.load(f) + + # Copy videos from managed dataset to training directory + for item in dataset: + src_video = managed_dataset_dir / item["media_path"] + dest_video = data_dir / Path(item["media_path"]).name + + if not dest_video.exists() and src_video.exists(): + shutil.copy2(src_video, dest_video) + + # Update the media_path to be relative to data_dir + item["media_path"] = Path(item["media_path"]).name + + # Save updated dataset with corrected paths + with open(training_captions_file, "w") as f: + json.dump(dataset, f, indent=2) + + # Check if preprocessing is needed + needs_preprocessing = self._should_preprocess_data(params.width, params.height, params.num_frames, [])[0] + + # Preprocess if needed (first time or resolution changed) + if needs_preprocessing: + # Clean up existing precomputed data + precomputed_dir = TRAINING_DATA_DIR / ".precomputed" + if precomputed_dir.exists(): + shutil.rmtree(precomputed_dir) + + success, error_msg = self._preprocess_dataset( + dataset_file=training_captions_file, + model_source=params.model_source, + width=params.width, + height=params.height, + num_frames=params.num_frames, + 
id_token=params.id_token, + ) + if not success: + return error_msg, gr.update(interactive=True) + + # Save current resolution config after successful preprocessing + self._save_resolution_config(params.width, params.height, params.num_frames) + + # Generate training config + config_params = TrainingConfigParams( + model_source=params.model_source, + learning_rate=params.learning_rate, + steps=params.steps, + lora_rank=params.lora_rank, + batch_size=params.batch_size, + validation_prompt=params.validation_prompt, + video_dims=(params.width, params.height, params.num_frames), + validation_interval=params.validation_interval, + push_to_hub=params.push_to_hub, + hub_model_id=params.hf_model_id if params.push_to_hub else None, + ) + + config = generate_training_config(config_params, str(data_dir)) + config_path = OUTPUTS_DIR / "train_config.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f, indent=4) + + # Run training + self.run_training(config_path) + + return "Training completed!", gr.update(interactive=True) + + except Exception as e: + return f"Error during training: {e!s}", gr.update(interactive=True) + finally: + # Clean up CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + + def run_training(self, config_path: Path) -> None: + """Run the training process and update progress.""" + # Reset training state at the start + self.training_state.reset() + + try: + # Load config from YAML + with open(config_path) as f: + config_dict = yaml.safe_load(f) + + # Initialize trainer config and trainer + trainer_config = LtxvTrainerConfig(**config_dict) + trainer = LtxvTrainer(trainer_config) + + def training_callback(step: int, total_steps: int, sampled_videos: list[Path] | None = None) -> None: + """Callback function to update training progress and show samples.""" + # Update progress + progress_pct = (step / total_steps) * 100 + self.training_state.update(progress=f"Step {step}/{total_steps} ({progress_pct:.1f}%)") + + # Update validation video at validation intervals + if step % trainer_config.validation.interval == 0 and sampled_videos: + # Convert the first sample to GIF + gif_path = _handle_validation_sample(step, sampled_videos[0]) + if gif_path: + self.training_state.update(validation=gif_path) + + logger.info("Starting training...") + + # Start training with callback + saved_path, stats = trainer.train(disable_progress_bars=False, step_callback=training_callback) + + # Save checkpoint and get paths + permanent_checkpoint_path, hf_repo = self._save_checkpoint(saved_path, trainer_config) + + # Update training outputs with completion status + self.training_state.update( + status="complete", + download=str(permanent_checkpoint_path), + hf_repo=hf_repo, + checkpoint_path=str(permanent_checkpoint_path), + ) + + logger.info(f"Training completed. Model saved to {permanent_checkpoint_path}") + logger.info(f"Training stats: {stats}") + + except Exception as e: + logger.error(f"Training failed: {e!s}", exc_info=True) + self.training_state.update(status="failed", error=str(e)) + raise + finally: + # Don't reset current_job here - let the UI handle it + if self.training_state.status == "running": + self.training_state.update(status="failed") + + def create_new_dataset(self, name: str) -> tuple[str, gr.Dropdown, list, dict, None, str, str, gr.Dropdown]: + """Create a new dataset. 
+ + Args: + name: Name of the dataset + + Returns: + Tuple of (status message, updated dropdown, empty gallery, empty stats, cleared fields, training dropdown) + """ + if not name or not name.strip(): + return "Please enter a dataset name", gr.update(), gr.update(), {}, None, "", "", gr.update() + + try: + self.dataset_manager.create_dataset(name) + datasets = self.dataset_manager.list_datasets() + # Return updated dropdown with new dataset selected, and clear all other fields + return ( + f"Created dataset: {name}", + gr.update(choices=datasets, value=name), + [], # Empty gallery + {"name": name, "total_videos": 0, "captioned": 0, "uncaptioned": 0, "preprocessed": False}, # Stats + None, # Clear selected video + "", # Clear video name + "", # Clear caption editor + gr.update(choices=datasets), # Update training tab dropdown + ) + except Exception as e: + return f"Error: {e}", gr.update(), gr.update(), {}, None, "", "", gr.update() + + def load_dataset(self, dataset_name: str) -> tuple[str | None, list, dict, None | str, str, str]: + """Load a dataset and display its contents. + + Args: + dataset_name: Name of the dataset to load + + Returns: + Tuple of (dataset name, gallery items, statistics, cleared video/name/caption) + """ + if not dataset_name: + return None, [], {}, None, "", "" + + self.current_dataset = dataset_name + items = self.dataset_manager.get_dataset_items(dataset_name) + stats = self.dataset_manager.get_dataset_stats(dataset_name) + + # Prepare gallery items (thumbnails with captions) + gallery_items = [ + ( + item["thumbnail"], + item["caption"][:50] + "..." + if len(item.get("caption", "")) > 50 + else item.get("caption", "") or "No caption", + ) + for item in items + if item["thumbnail"] + ] + + # Clear the video preview and caption editor when switching datasets + return dataset_name, gallery_items, stats, None, "", "" + + def upload_videos_to_dataset(self, files: list, dataset_name: str) -> tuple[str, gr.Gallery, dict]: + """Upload videos to a dataset. + + Args: + files: List of video file paths + dataset_name: Name of the dataset + + Returns: + Tuple of (status message, updated gallery, updated stats) + """ + if not dataset_name: + return "Please select a dataset first", gr.update(), gr.update() + + if not files: + return "No files selected", gr.update(), gr.update() + + try: + result = self.dataset_manager.add_videos(dataset_name, files) + + message = f"Added {len(result['added'])} videos" + if result["failed"]: + message += f", {len(result['failed'])} failed" + for failed in result["failed"][:3]: # Show first 3 failures + message += f"\n- {failed['video']}: {failed['error']}" + + # Refresh gallery and stats + items = self.dataset_manager.get_dataset_items(dataset_name) + gallery_items = [ + ( + i["thumbnail"], + ( + i.get("caption", "")[:50] + "..." + if len(i.get("caption", "")) > 50 + else i.get("caption", "") or "No caption" + ), + ) + for i in items + if i["thumbnail"] + ] + stats = self.dataset_manager.get_dataset_stats(dataset_name) + + return message, gr.update(value=gallery_items), stats + except Exception as e: + return f"Error: {e}", gr.update(), gr.update() + + def upload_videos_with_references( + self, video_files: list, reference_files: list, dataset_name: str + ) -> tuple[str, gr.Gallery, dict]: + """Upload videos with reference videos to a dataset for IC-LoRA training. 
+ + Args: + video_files: List of target video file paths + reference_files: List of reference video file paths + dataset_name: Name of the dataset + + Returns: + Tuple of (status message, updated gallery, updated stats) + """ + if not dataset_name: + return "Please select a dataset first", gr.update(), gr.update() + + if not video_files: + return "No target videos selected", gr.update(), gr.update() + + if not reference_files: + return "No reference videos selected", gr.update(), gr.update() + + try: + result = self.dataset_manager.add_videos(dataset_name, video_files, reference_files) + + message = f"Added {len(result['added'])} video pairs (target + reference)" + if result["failed"]: + message += f", {len(result['failed'])} failed" + for failed in result["failed"][:3]: # Show first 3 failures + message += f"\n- {failed['video']}: {failed['error']}" + + # Refresh gallery and stats + items = self.dataset_manager.get_dataset_items(dataset_name) + gallery_items = [ + ( + i["thumbnail"], + ( + i.get("caption", "")[:50] + "..." + if len(i.get("caption", "")) > 50 + else i.get("caption", "") or "No caption" + ), + ) + for i in items + if i["thumbnail"] + ] + stats = self.dataset_manager.get_dataset_stats(dataset_name) + + return message, gr.update(value=gallery_items), stats + except Exception as e: + return f"Error: {e}", gr.update(), gr.update() + + def select_video_from_gallery(self, evt: gr.SelectData, dataset_name: str) -> tuple[str | None, str, str]: + """Handle video selection from gallery. + + Args: + evt: Selection event data + dataset_name: Name of the dataset + + Returns: + Tuple of (video path, video name, caption) + """ + if not dataset_name: + return None, "", "" + + items = self.dataset_manager.get_dataset_items(dataset_name) + if evt.index >= len(items): + return None, "", "" + + selected_item = items[evt.index] + video_path = selected_item["full_video_path"] + video_name = Path(selected_item["media_path"]).name + caption = selected_item.get("caption", "") + + return video_path, video_name, caption + + def save_caption_edit(self, dataset_name: str, video_name: str, caption: str) -> tuple[str, dict]: + """Save edited caption for a video. + + Args: + dataset_name: Name of the dataset + video_name: Name of the video file + caption: New caption text + + Returns: + Tuple of (status message, updated stats) + """ + if not dataset_name or not video_name: + return "No video selected", gr.update() + + try: + self.dataset_manager.update_caption(dataset_name, video_name, caption) + stats = self.dataset_manager.get_dataset_stats(dataset_name) + return f"Caption saved for {video_name}", stats + except Exception as e: + return f"Error: {e}", gr.update() + + def auto_caption_single( + self, dataset_name: str, video_name: str, vlm_instruction: str, use_8bit: bool + ) -> tuple[str, str]: + """Auto-generate caption for a single video. 
+ + Args: + dataset_name: Name of the dataset + video_name: Name of the video file + vlm_instruction: Custom VLM instruction for captioning + use_8bit: Whether to use 8-bit quantization + + Returns: + Tuple of (generated caption, status message) + """ + if not dataset_name or not video_name: + return "", "No video selected" + + try: + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Use custom instruction if provided, otherwise use default + instruction = ( + vlm_instruction if vlm_instruction and vlm_instruction.strip() else DEFAULT_VLM_CAPTION_INSTRUCTION + ) + + captioner = create_captioner( + captioner_type=CaptionerType.QWEN_25_VL, + use_8bit=use_8bit, + vlm_instruction=instruction, + device=device, + ) + + video_path = Path(self.dataset_manager.datasets_root) / dataset_name / "videos" / video_name + caption = captioner.caption(str(video_path)) + + self.dataset_manager.update_caption(dataset_name, video_name, caption) + return caption, f"Caption generated for {video_name}" + except Exception as e: + return "", f"Error: {e}" + + def auto_caption_all_uncaptioned(self, dataset_name: str, vlm_instruction: str, use_8bit: bool) -> tuple[str, dict]: + """Auto-generate captions for all uncaptioned videos. + + Args: + dataset_name: Name of the dataset + vlm_instruction: Custom VLM instruction for captioning + use_8bit: Whether to use 8-bit quantization + + Returns: + Tuple of (status message, updated stats) + """ + if not dataset_name: + return "Please select a dataset first", gr.update() + + try: + items = self.dataset_manager.get_dataset_items(dataset_name) + uncaptioned = [i for i in items if not i.get("caption") or not i.get("caption").strip()] + + if not uncaptioned: + stats = self.dataset_manager.get_dataset_stats(dataset_name) + return "All videos already have captions", stats + + device = "cuda" if torch.cuda.is_available() else "cpu" + + # Use custom instruction if provided, otherwise use default + instruction = ( + vlm_instruction if vlm_instruction and vlm_instruction.strip() else DEFAULT_VLM_CAPTION_INSTRUCTION + ) + + captioner = create_captioner( + captioner_type=CaptionerType.QWEN_25_VL, + use_8bit=use_8bit, + vlm_instruction=instruction, + device=device, + ) + + for _idx, item in enumerate(uncaptioned): + video_path = Path(self.dataset_manager.datasets_root) / dataset_name / item["media_path"] + caption = captioner.caption(str(video_path)) + video_name = Path(item["media_path"]).name + self.dataset_manager.update_caption(dataset_name, video_name, caption) + + stats = self.dataset_manager.get_dataset_stats(dataset_name) + return f"Generated captions for {len(uncaptioned)} videos", stats + except Exception as e: + return f"Error: {e}", gr.update() + + def validate_dataset_ui(self, dataset_name: str) -> dict: + """Validate a dataset and return issues. + + Args: + dataset_name: Name of the dataset + + Returns: + Validation results dictionary + """ + if not dataset_name: + return {"error": "Please select a dataset first"} + + try: + return self.dataset_manager.validate_dataset(dataset_name) + except Exception as e: + return {"error": str(e)} + + def preprocess_dataset_ui( + self, + dataset_name: str, + width: int, + height: int, + num_frames: int, + id_token: str, + decode_videos: bool, + load_text_encoder_8bit: bool, + ) -> tuple[str, dict]: + """Preprocess a dataset from the UI. 
+ + Args: + dataset_name: Name of the dataset + width: Video width + height: Video height + num_frames: Number of frames + id_token: ID token to prepend to captions + decode_videos: Whether to decode videos for verification + load_text_encoder_8bit: Whether to load text encoder in 8-bit mode + + Returns: + Tuple of (status message, updated stats) + """ + if not dataset_name: + return "Please select a dataset first", gr.update() + + try: + dataset_dir = self.dataset_manager.datasets_root / dataset_name + dataset_json = dataset_dir / "dataset.json" + + if not dataset_json.exists(): + return "Dataset JSON not found", gr.update() + + # Check if dataset has reference videos + with open(dataset_json) as f: + dataset_items = json.load(f) + + has_references = any(item.get("reference_path") for item in dataset_items) + + resolution_buckets = f"{width}x{height}x{num_frames}" + parsed_buckets = parse_resolution_buckets(resolution_buckets) + + preprocess_dataset( + dataset_file=str(dataset_json), + caption_column="caption", + video_column="media_path", + resolution_buckets=parsed_buckets, + batch_size=1, + output_dir=None, # Will use default .precomputed + id_token=id_token if id_token and id_token.strip() else None, + vae_tiling=False, + decode_videos=decode_videos, + model_source=LtxvModelVersion.latest(), + device="cuda" if torch.cuda.is_available() else "cpu", + load_text_encoder_in_8bit=load_text_encoder_8bit, + reference_column="reference_path" if has_references else None, + ) + + stats = self.dataset_manager.get_dataset_stats(dataset_name) + message = "Preprocessing complete!" + if has_references: + message += " (IC-LoRA reference videos processed)" + if decode_videos: + decoded_dir = dataset_dir / ".precomputed" / "decoded_videos" + message += f"\nDecoded videos saved to: {decoded_dir}" + return message, stats + except Exception as e: + logger.error(f"Preprocessing failed: {e}", exc_info=True) + return f"Preprocessing failed: {e}", gr.update() + + def delete_video_from_dataset( + self, dataset_name: str, video_name: str + ) -> tuple[str, gr.Gallery, None, str, str, dict]: + """Delete a video from a dataset. + + Args: + dataset_name: Name of the dataset + video_name: Name of the video to delete + + Returns: + Tuple of (status message, updated gallery, cleared video/name/caption, updated stats) + """ + if not dataset_name or not video_name: + return "No video selected", gr.update(), None, "", "", gr.update() + + try: + self.dataset_manager.delete_video(dataset_name, video_name) + + # Refresh gallery + items = self.dataset_manager.get_dataset_items(dataset_name) + gallery_items = [ + ( + i["thumbnail"], + ( + i.get("caption", "")[:50] + "..." + if len(i.get("caption", "")) > 50 + else i.get("caption", "") or "No caption" + ), + ) + for i in items + if i["thumbnail"] + ] + stats = self.dataset_manager.get_dataset_stats(dataset_name) + + return f"Deleted {video_name}", gr.update(value=gallery_items), None, "", "", stats + except Exception as e: + return f"Error: {e}", gr.update(), None, "", "", gr.update() + + def split_scenes_and_add( + self, + dataset_name: str, + video_file: str, + detector_type: str, + min_scene_length: int | None, + threshold: float | None, + filter_shorter: str | None, + max_scenes: int | None, + ) -> tuple[str, gr.Gallery, dict]: + """Split a video into scenes and add them to the dataset. 
+ + Args: + dataset_name: Name of the dataset + video_file: Path to the video file + detector_type: Type of scene detector to use + min_scene_length: Minimum scene length in frames + threshold: Detection threshold + filter_shorter: Filter scenes shorter than this duration + max_scenes: Maximum number of scenes (0 for unlimited) + + Returns: + Tuple of (status message, updated gallery, updated stats) + """ + if not dataset_name: + return "Please select a dataset first", gr.update(), gr.update() + + if not video_file: + return "Please upload a video file", gr.update(), gr.update() + + try: + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + video_path = Path(video_file) + + scene_files = self.dataset_manager.split_video_scenes( + video_path=video_path, + output_dir=temp_path, + detector_type=detector_type, + min_scene_length=int(min_scene_length) if min_scene_length else None, + threshold=float(threshold) if threshold else None, + filter_shorter_than=filter_shorter if filter_shorter and filter_shorter.strip() else None, + max_scenes=int(max_scenes) if max_scenes and max_scenes > 0 else None, + ) + + if not scene_files: + return "No scenes detected", gr.update(), gr.update() + + result = self.dataset_manager.add_videos(dataset_name, [str(f) for f in scene_files]) + + message = ( + f"Split into {len(scene_files)} scenes using {detector_type} detector. " + f"Added {len(result['added'])} scenes to dataset." + ) + if result["failed"]: + message += f" {len(result['failed'])} failed." + + items = self.dataset_manager.get_dataset_items(dataset_name) + gallery_items = [ + ( + i["thumbnail"], + ( + i.get("caption", "")[:50] + "..." + if len(i.get("caption", "")) > 50 + else i.get("caption", "") or "No caption" + ), + ) + for i in items + if i["thumbnail"] + ] + stats = self.dataset_manager.get_dataset_stats(dataset_name) + + return message, gr.update(value=gallery_items), stats + + except Exception as e: + logger.error(f"Scene splitting failed: {e}", exc_info=True) + return f"Error: {e}", gr.update(), gr.update() + + def filter_dataset_gallery(self, dataset_name: str, search_text: str) -> list: + """Filter gallery by caption search. + + Args: + dataset_name: Name of the dataset + search_text: Text to search for in captions + + Returns: + Filtered gallery items + """ + if not dataset_name: + return [] + + items = self.dataset_manager.get_dataset_items(dataset_name) + + if search_text and search_text.strip(): + search_lower = search_text.lower() + items = [i for i in items if search_lower in i.get("caption", "").lower()] + + gallery_items = [ + ( + i["thumbnail"], + ( + i.get("caption", "")[:50] + "..." + if len(i.get("caption", "")) > 50 + else i.get("caption", "") or "No caption" + ), + ) + for i in items + if i["thumbnail"] + ] + + return gallery_items + + def refresh_job_list(self, show_all: bool = False) -> list[list]: + """Get formatted list of jobs for display. + + Args: + show_all: If True, show all jobs. If False, show only latest job. 
+ + Returns: + List of job rows for dataframe display + """ + jobs = self.job_db.get_all_jobs() + + # If not showing all, only show the most recent job + if not show_all and jobs: + jobs = [jobs[0]] # Jobs are already ordered by ID desc (most recent first) + + # Format for dataframe + rows = [] + for job in jobs: + rows.append( + [ + job["id"], + job["status"], + job["dataset_name"], + job["created_at"][:19] if job.get("created_at") else "", + job.get("progress", "")[:50] if job.get("progress") else "", + ] + ) + + return rows + + def get_latest_job_id(self) -> int: + """Get the ID of the latest job. + + Returns: + Latest job ID or 0 if no jobs + """ + jobs = self.job_db.get_all_jobs() + return jobs[0]["id"] if jobs else 0 + + def get_running_job(self) -> dict | None: + """Get the currently running job if any. + + Returns: + Running job dict or None + """ + jobs = self.job_db.get_all_jobs() + for job in jobs: + if job["status"] == JobStatus.RUNNING: + return job + return None + + def get_current_job_display(self) -> tuple[str, str, str | None, str, str]: + """Get formatted display for the current running job. + + Returns: + Tuple of (status_html, job_info, validation_sample, validation_prompt, logs) + """ + running_job = self.get_running_job() + + if not running_job: + return ( + '
<div>No job currently running</div>
', + "", + None, + "", + "", + ) + + # Format job info + job_info = f"""**Job #{running_job["id"]}** - {running_job["dataset_name"]} +**Status:** {running_job["status"]} +**Progress:** {running_job.get("progress", "Starting...")} +**Started:** {running_job.get("started_at", "N/A")[:19] if running_job.get("started_at") else "N/A"} +""" + + # Get validation sample + validation_sample = running_job.get("validation_sample") + if validation_sample and not Path(validation_sample).is_absolute(): + validation_sample = str(PROJECT_ROOT / validation_sample) + + # Get validation prompt from job params + validation_prompt = running_job.get("params", {}).get("validation_prompt", "") + + # Get logs + logs = running_job.get("logs", "") or "Waiting for logs..." + + # Status HTML with color coding + status_color = "#2563eb" # blue for running + status_html = ( + f'
<div style="color: {status_color};">' + f'▶️ Training in Progress - Job #{running_job["id"]}' + f"</div>
" + ) + + return status_html, job_info, validation_sample, validation_prompt, logs + + def view_job_details(self, job_id: int | float) -> tuple[int, dict, str | None, str, gr.Accordion]: + """View job details by ID. + + Args: + job_id: The job ID to view + + Returns: + Tuple of (job_id, job_details, validation_sample, logs, accordion_update) + """ + try: + if not job_id or job_id == 0: + return 0, {}, None, "Please enter a valid Job ID", gr.Accordion(open=False) + + job = self.job_db.get_job(int(job_id)) + if job: + logs = job.get("logs", "") or "No logs available yet" + validation_sample = job.get("validation_sample") + # Convert path to absolute if it's a relative path + if validation_sample and not Path(validation_sample).is_absolute(): + validation_sample = str(PROJECT_ROOT / validation_sample) + return int(job_id), job, validation_sample, logs, gr.Accordion(open=True) + else: + return int(job_id), {}, None, f"Job #{int(job_id)} not found", gr.Accordion(open=True) + except (TypeError, ValueError) as e: + logger.error(f"Error viewing job details: {e}") + return 0, {}, None, f"Error: {e}", gr.Accordion(open=False) + + def clear_completed_jobs(self) -> tuple[str, list[list]]: + """Clear all completed, failed, and cancelled jobs. + + Returns: + Tuple of (status message, updated job list) + """ + self.job_db.clear_completed_jobs() + return "✅ Cleared all completed jobs", self.refresh_job_list(show_all=True) + + def stop_worker(self) -> tuple[str, list[list]]: + """Stop the worker and cancel the current running job. + + Returns: + Tuple of (status message, updated job list) + """ + try: + # First, cancel any running jobs + jobs = self.job_db.get_all_jobs() + cancelled_jobs = [] + for job in jobs: + if job["status"] == JobStatus.RUNNING: + self.job_db.update_job_status( + job["id"], JobStatus.CANCELLED, error_message="Worker stopped by user" + ) + cancelled_jobs.append(job["id"]) + + # Send shutdown signal to worker + db_path = PROJECT_ROOT / "jobs.db" + shutdown_signal_file = db_path.parent / ".worker_shutdown_signal" + shutdown_signal_file.touch() + + if cancelled_jobs: + msg = f"✅ Cancelled job(s) #{', #'.join(map(str, cancelled_jobs))} and sent shutdown signal to worker." + else: + msg = "✅ Shutdown signal sent to worker. No running jobs to cancel." + + return msg, self.refresh_job_list(show_all=True) + except Exception as e: + logger.error(f"Error stopping worker: {e}") + return f"❌ Error: {e}", self.refresh_job_list(show_all=True) + + def start_worker(self) -> str: + """Start the worker process. + + Returns: + Status message + """ + try: + import subprocess + import sys + + # Check if already running + status = self.check_worker_status() + if "running" in status.lower(): + return f"⚠️ Worker is already running\n{status}" + + # Remove shutdown signal if it exists + shutdown_signal_file = PROJECT_ROOT / ".worker_shutdown_signal" + if shutdown_signal_file.exists(): + shutdown_signal_file.unlink() + + # Start worker process + worker_script = PROJECT_ROOT / "scripts" / "jobs" / "run_worker.py" + python_exe = sys.executable + + subprocess.Popen( + [python_exe, str(worker_script)], + cwd=str(PROJECT_ROOT), + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + start_new_session=True, # Detach from parent + ) + + return "✅ Worker started successfully" + except Exception as e: + logger.error(f"Error starting worker: {e}") + return f"❌ Error starting worker: {e}" + + def check_worker_status(self) -> str: + """Check if worker is running. 
+ + Returns: + Worker status message + """ + try: + import psutil + + # Look for worker process + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + try: + cmdline = proc.info.get("cmdline", []) + if cmdline and "run_worker.py" in " ".join(cmdline): + return f"✅ Worker is running (PID: {proc.info['pid']})" + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + return "⚠️ Worker is not running. Start it with: python scripts/jobs/run_worker.py" + except ImportError: + return "💡 Install psutil to check worker status: pip install psutil" + except Exception as e: + return f"❓ Could not determine worker status: {e}" + + def create_ui(self) -> gr.Blocks: # noqa: PLR0915 + """Create the Gradio UI.""" + with gr.Blocks() as blocks: + gr.Markdown("# LTX-Video Trainer") + + with gr.Tab("Training"): + gr.Markdown("# 🎬 Train LTX-Video LoRA") + + # Dataset Selection + with gr.Group(): + gr.Markdown("### 📁 Dataset") + dataset_for_training = gr.Dropdown( + choices=self.dataset_manager.list_datasets(), + label="Select Dataset", + interactive=True, + info="Choose a dataset you created in the Datasets tab", + ) + + # Training Configuration + with gr.Row(): + with gr.Column(scale=1): + with gr.Group(): + gr.Markdown("### ⚙️ Training Parameters") + model_source = gr.Dropdown( + choices=[str(v) for v in LtxvModelVersion], + value=str(LtxvModelVersion.latest()), + label="Model Version", + info="Select the model version to use for training", + ) + + with gr.Row(): + lr = gr.Number(value=2e-4, label="Learning Rate", scale=1) + steps = gr.Number( + value=1500, + label="Training Steps", + precision=0, + info="Total training steps", + scale=1, + ) + + with gr.Row(): + lora_rank = gr.Dropdown( + choices=list(range(8, 257, 8)), + value=128, + label="LoRA Rank", + info="Higher = more capacity", + scale=1, + ) + batch_size = gr.Number(value=1, label="Batch Size", precision=0, scale=1) + + validation_interval = gr.Number( + value=100, + label="Validation Interval", + precision=0, + info="Steps between validation samples", + minimum=1, + ) + + with gr.Group(): + gr.Markdown("### 🎞️ Video Settings") + id_token = gr.Textbox( + label="LoRA ID Token", + placeholder="Optional: e.g., or ", + value="", + info="Prepended to training captions", + ) + + with gr.Row(): + width = gr.Dropdown( + choices=list(range(256, 1025, 32)), + value=768, + label="Width", + info="Multiple of 32", + ) + height = gr.Dropdown( + choices=list(range(256, 1025, 32)), + value=768, + label="Height", + info="Multiple of 32", + ) + num_frames = gr.Dropdown( + choices=list(range(9, 129, 8)), + value=25, + label="Frames", + info="Multiple of 8", + ) + + with gr.Column(scale=1), gr.Group(): + # Training Monitor + gr.Markdown("### 📊 Training Monitor") + + with gr.Row(): + train_btn = gr.Button("▶️ Start Training", variant="primary", size="lg", scale=2) + reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1) + + current_job_display = gr.Markdown(value="", visible=True, label="Current Job") + + self.status_output = gr.Textbox(label="Status", interactive=False) + self.progress_output = gr.Textbox(label="Progress", interactive=False) + + gr.Markdown("**Validation Sample**") + self.validation_prompt = gr.Textbox( + label="Validation Prompt", + placeholder="Enter the prompt to use for validation samples", + value="a professional portrait video of a person with blurry bokeh background", + interactive=True, + info="Prompt used to generate sample videos during training (view in Jobs tab)", + lines=2, + ) + + # Advanced 
Settings (Collapsible) + with gr.Accordion("🚀 HuggingFace Hub (Optional)", open=False): + push_to_hub = gr.Checkbox( + label="Push to HuggingFace Hub", + value=False, + info="Automatically upload trained model to HuggingFace Hub", + ) + with gr.Row(): + hf_token = gr.Textbox( + label="HuggingFace Token", + type="password", + placeholder="hf_...", + info="Your HuggingFace API token", + ) + hf_model_id = gr.Textbox( + label="Model ID", + placeholder="username/model-name", + info="Repository name on HuggingFace", + ) + + # Results Section + with gr.Group(): + gr.Markdown("### ✅ Training Results") + with gr.Row(): + self.download_btn = gr.DownloadButton( + label="📥 Download LoRA Weights", + visible=False, + interactive=True, + size="lg", + ) + self.hf_repo_link = gr.HTML( + value="", + visible=False, + label="HuggingFace Hub", + ) + + # Jobs Queue Link + gr.Markdown("---") # Separator + gr.Markdown("### 📋 Training Queue Status") + gr.Markdown( + "💡 **Tip:** Switch to the **Jobs** tab to monitor training progress, " + "view logs, and manage the queue." + ) + + # Jobs Tab + with gr.Tab("Jobs"): + gr.Markdown("# 📋 Training Jobs") + + # Worker Status Section + with gr.Group(): + gr.Markdown("## 🔧 Worker Status") + with gr.Row(): + check_worker_btn = gr.Button("🔍 Check Status", size="sm", scale=1) + stop_worker_btn = gr.Button("🛑 Stop Worker", variant="stop", size="sm", scale=1) + + queue_status = gr.Textbox( + label="", + interactive=False, + show_label=False, + container=True, + placeholder="Worker status will appear here...", + ) + + gr.Markdown("---") + + # Current Running Job Section + with gr.Group(): + gr.Markdown("## ▶️ Current Training Job") + + current_job_status = gr.HTML( + value=( + '
' + "No job currently running
" + ) + ) + + with gr.Row(): + with gr.Column(scale=2): + current_job_info = gr.Markdown(value="") + + gr.Markdown("#### 🎬 Latest Validation Sample") + current_job_sample = gr.Video( + label="", + show_label=False, + interactive=False, + autoplay=True, + loop=True, + height=300, + ) + + gr.Markdown("#### 💬 Validation Prompt") + current_job_prompt = gr.Textbox( + label="", + show_label=False, + interactive=False, + lines=2, + placeholder="Validation prompt will appear here...", + ) + + with gr.Column(scale=3): + gr.Markdown("#### 📝 Training Logs (Live)") + current_job_logs = gr.Textbox( + label="", + lines=25, + max_lines=30, + interactive=False, + show_copy_button=True, + placeholder="Logs will appear here when training starts...", + autoscroll=True, + ) + + gr.Markdown("---") + + # Job History Section + with gr.Group(): + gr.Markdown("## 📚 Job History") + + with gr.Row(): + refresh_jobs_btn = gr.Button("🔄 Refresh Jobs", size="sm") + clear_completed_btn = gr.Button("🗑️ Clear Completed", size="sm", variant="secondary") + + job_table = gr.Dataframe( + headers=["ID", "Status", "Dataset", "Created", "Progress"], + datatype=["number", "str", "str", "str", "str"], + label="All Jobs", + interactive=False, + wrap=True, + row_count=10, + ) + + # Manual job selection for viewing older jobs + with gr.Accordion("🔍 View Specific Job", open=False): + with gr.Row(): + selected_job_id = gr.Number( + label="Job ID", + precision=0, + value=0, + interactive=True, + scale=1, + ) + view_job_btn = gr.Button("View", size="sm", scale=1) + + selected_job_details = gr.JSON(label="Job Configuration", value={}) + + gr.Markdown("#### Validation Sample") + selected_job_sample = gr.Video( + label="", + show_label=False, + interactive=False, + autoplay=True, + loop=True, + height=250, + ) + + gr.Markdown("#### Logs") + selected_job_logs = gr.Textbox( + label="", + lines=15, + interactive=False, + show_copy_button=True, + ) + + with gr.Accordion("💡 Worker Information", open=False): + gr.Markdown( + "### Worker Management\n\n" + "The worker processes training jobs from the queue.\n\n" + "**To start:** Click the '▶️ Start Worker' button above, or run:\n" + "```bash\n" + "python scripts/jobs/run_worker.py\n" + "```\n\n" + "**To stop:** Click the '🛑 Stop Worker' button above.\n\n" + "**To check status:** Click the '🔍 Check Worker Status' button." 
+ ) + + gr.Textbox( + label="Worker Status", + interactive=False, + placeholder="Click 'Check Worker Status' to see if worker is running...", + ) + + # Event handlers + + # Auto-refresh current job on timer + def auto_refresh_current_job() -> tuple[str, str, str | None, str, str, list, str]: + """Auto-refresh the current running job display and check worker status.""" + status_html, job_info, validation_sample, validation_prompt, logs = self.get_current_job_display() + job_list = self.refresh_job_list(show_all=True) + worker_status = self.check_worker_status() + return status_html, job_info, validation_sample, validation_prompt, logs, job_list, worker_status + + job_refresh_timer = gr.Timer(value=5) + job_refresh_timer.tick( + fn=auto_refresh_current_job, + outputs=[ + current_job_status, + current_job_info, + current_job_sample, + current_job_prompt, + current_job_logs, + job_table, + queue_status, + ], + show_progress=False, + ) + + # Manual refresh + refresh_jobs_btn.click( + auto_refresh_current_job, + outputs=[ + current_job_status, + current_job_info, + current_job_sample, + current_job_prompt, + current_job_logs, + job_table, + queue_status, + ], + ) + + # Clear completed jobs + clear_completed_btn.click( + self.clear_completed_jobs, + outputs=[queue_status, job_table], + ) + + # Worker control + check_worker_btn.click( + self.check_worker_status, + outputs=[queue_status], + ) + + stop_worker_btn.click( + self.stop_worker, + outputs=[queue_status, job_table], + ) + + # View specific job + def view_specific_job(job_id: int | float) -> tuple[dict, str | None, str]: + if not job_id or job_id == 0: + return {}, None, "Please enter a valid Job ID" + job_id_int, details, sample, logs, _ = self.view_job_details(job_id) + return details, sample, logs + + view_job_btn.click( + view_specific_job, + inputs=[selected_job_id], + outputs=[selected_job_details, selected_job_sample, selected_job_logs], + ) + + selected_job_id.submit( + view_specific_job, + inputs=[selected_job_id], + outputs=[selected_job_details, selected_job_sample, selected_job_logs], + ) + + # Datasets Tab + with gr.Tab("Datasets"), gr.Row(): + with gr.Column(scale=1): + # Dataset selection/creation + gr.Markdown("## Manage Datasets") + + dataset_dropdown = gr.Dropdown( + choices=self.dataset_manager.list_datasets(), + label="Select Dataset", + interactive=True, + ) + + with gr.Row(): + new_dataset_name = gr.Textbox(label="New Dataset Name", placeholder="my_dataset") + create_dataset_btn = gr.Button("Create Dataset", variant="primary") + + # Dataset statistics + stats_box = gr.JSON(label="Dataset Statistics", value={}) + + # Captioning settings + gr.Markdown("### Caption Settings") + + with gr.Accordion("🎨 VLM Captioning Options", open=False): + vlm_instruction = gr.Textbox( + label="Custom VLM Instruction", + placeholder="Leave empty to use default instruction", + value="", + lines=3, + info="Custom instruction for the vision-language model when generating captions", + ) + vlm_use_8bit = gr.Checkbox( + label="Use 8-bit Quantization", + value=True, + info="Reduces VRAM usage during captioning", + ) + + # Batch operations + gr.Markdown("### Batch Operations") + + with gr.Row(): + auto_caption_btn = gr.Button("Auto-Caption All Uncaptioned", size="sm") + validate_btn = gr.Button("Validate Dataset", size="sm") + + validation_result = gr.JSON(label="Validation Results", value={}) + + # Preprocessing options + gr.Markdown("### Preprocessing Options") + + with gr.Accordion("⚙️ Advanced Preprocessing", open=False): + 
decode_videos_check = gr.Checkbox( + label="Decode Videos for Verification", + value=False, + info="Decode preprocessed latents to verify quality (slower, uses more disk space)", + ) + load_text_encoder_8bit = gr.Checkbox( + label="Load Text Encoder in 8-bit", + value=False, + info="Reduces VRAM usage during text embedding computation", + ) + + with gr.Row(): + preprocess_btn = gr.Button("Preprocess Dataset", variant="primary") + + preprocess_status = gr.Textbox(label="Preprocessing Status", interactive=False) + + with gr.Column(scale=3): + # Video upload area + gr.Markdown("## Upload Videos") + + video_uploader = gr.File( + label="Drag and drop videos here", + file_count="multiple", + file_types=["video"], + type="filepath", + ) + + upload_btn = gr.Button("Add to Dataset") + upload_status = gr.Textbox(label="Upload Status", interactive=False) + + # IC-LoRA reference videos + with gr.Accordion("🎯 IC-LoRA: Add Reference Videos (Optional)", open=False): + gr.Markdown( + "**IC-LoRA Training:** Upload reference videos paired with target videos " + "for advanced video-to-video transformations. " + "Reference videos must match target videos in count, resolution, and length." + ) + reference_uploader = gr.File( + label="Reference Videos (for IC-LoRA) - Upload in same order as target videos", + file_count="multiple", + file_types=["video"], + type="filepath", + ) + upload_with_ref_btn = gr.Button("Add Videos with References", variant="secondary") + + # Scene splitting + with gr.Accordion("✂️ Split Long Videos into Scenes", open=False): + gr.Markdown( + "Upload a long video and automatically split it into individual scenes " + "using advanced detection algorithms" + ) + scene_video_uploader = gr.File( + label="Video to Split", + file_count="single", + file_types=["video"], + type="filepath", + ) + with gr.Row(): + scene_detector_type = gr.Dropdown( + choices=["content", "adaptive", "threshold", "histogram"], + value="content", + label="Detection Algorithm", + info=( + "Content: HSV color changes | Adaptive: Two-phase cuts | " + "Threshold: Fades | Histogram: YUV changes" + ), + ) + scene_threshold = gr.Number( + label="Detection Threshold", + value=27.0, + info="Lower = more sensitive", + ) + with gr.Row(): + scene_min_length = gr.Number( + label="Min Scene Length (frames)", + value=30, + precision=0, + info="Minimum scene duration during detection", + ) + scene_filter_shorter = gr.Textbox( + label="Filter Shorter Than", + value="2s", + info="Filter scenes shorter than duration (e.g., '2s', '60' frames)", + ) + scene_max_scenes = gr.Number( + label="Max Scenes (0=unlimited)", + value=0, + precision=0, + info="Maximum number of scenes to produce", + ) + split_scenes_btn = gr.Button("Split Scenes and Add to Dataset", variant="primary") + + # Visual dataset browser + gr.Markdown("## Dataset Browser") + + with gr.Row(): + search_box = gr.Textbox( + label="Search captions", placeholder="Filter by caption text...", scale=4 + ) + refresh_btn = gr.Button("🔄 Refresh", size="sm", scale=1) + + # Video gallery with captions + dataset_gallery = gr.Gallery( + label="Videos", + columns=4, + height="auto", + object_fit="contain", + allow_preview=True, + ) + + # Selected video editor + gr.Markdown("### Edit Selected Video") + + with gr.Row(): + selected_video = gr.Video(label="Preview", interactive=False) + + with gr.Column(): + selected_video_name = gr.Textbox(label="Video Name", interactive=False) + + caption_editor = gr.Textbox( + label="Caption", + placeholder="Enter caption for this video...", + lines=3, + 
interactive=True, + ) + + with gr.Row(): + save_caption_btn = gr.Button("Save Caption", variant="primary") + generate_caption_btn = gr.Button("Auto-Generate") + delete_video_btn = gr.Button("Delete Video", variant="stop") + + # Event handlers + # Update HF fields visibility based on push_to_hub checkbox + push_to_hub.change( + lambda x: { + hf_token: gr.update(visible=x), + hf_model_id: gr.update(visible=x), + }, + inputs=[push_to_hub], + outputs=[hf_token, hf_model_id], + ) + + train_btn.click( + lambda dataset_name, + validation_prompt, + lr, + steps, + lora_rank, + batch_size, + model_source, + width, + height, + num_frames, + push_to_hub, + hf_model_id, + hf_token, + id_token, + validation_interval: self.start_training( + TrainingParams( + dataset_name=dataset_name, + validation_prompt=validation_prompt, + learning_rate=lr, + steps=steps, + lora_rank=lora_rank, + batch_size=batch_size, + model_source=model_source, + width=width, + height=height, + num_frames=num_frames, + push_to_hub=push_to_hub, + hf_model_id=hf_model_id, + hf_token=hf_token, + id_token=id_token, + validation_interval=validation_interval, + ) + ), + inputs=[ + dataset_for_training, + self.validation_prompt, + lr, + steps, + lora_rank, + batch_size, + model_source, + width, + height, + num_frames, + push_to_hub, + hf_model_id, + hf_token, + id_token, + validation_interval, + ], + outputs=[self.status_output, train_btn], + ) + + # Update timer to use class method + timer = gr.Timer(value=10) # 10 second interval + timer.tick( + fn=self.update_progress, + inputs=None, + outputs=[ + current_job_display, + self.status_output, + self.download_btn, + self.hf_repo_link, + self.hf_repo_link, + ], + show_progress=False, + ) + + # Handle download button click + self.download_btn.click(self.get_model_path, inputs=None, outputs=[self.download_btn]) + + # Handle reset button click + reset_btn.click( + self.reset_interface, + inputs=None, + outputs=[ + self.validation_prompt, + self.status_output, + self.progress_output, + self.download_btn, + self.hf_repo_link, + ], + ) + + # Dataset event handlers + create_dataset_btn.click( + self.create_new_dataset, + inputs=[new_dataset_name], + outputs=[ + upload_status, + dataset_dropdown, + dataset_gallery, + stats_box, + selected_video, + selected_video_name, + caption_editor, + dataset_for_training, # Update training tab dropdown + ], + ) + + dataset_dropdown.change( + self.load_dataset, + inputs=[dataset_dropdown], + outputs=[ + dataset_dropdown, + dataset_gallery, + stats_box, + selected_video, + selected_video_name, + caption_editor, + ], + ) + + upload_btn.click( + self.upload_videos_to_dataset, + inputs=[video_uploader, dataset_dropdown], + outputs=[upload_status, dataset_gallery, stats_box], + ) + + upload_with_ref_btn.click( + self.upload_videos_with_references, + inputs=[video_uploader, reference_uploader, dataset_dropdown], + outputs=[upload_status, dataset_gallery, stats_box], + ) + + split_scenes_btn.click( + self.split_scenes_and_add, + inputs=[ + dataset_dropdown, + scene_video_uploader, + scene_detector_type, + scene_min_length, + scene_threshold, + scene_filter_shorter, + scene_max_scenes, + ], + outputs=[upload_status, dataset_gallery, stats_box], + ) + + dataset_gallery.select( + self.select_video_from_gallery, + inputs=[dataset_dropdown], + outputs=[selected_video, selected_video_name, caption_editor], + ) + + save_caption_btn.click( + self.save_caption_edit, + inputs=[dataset_dropdown, selected_video_name, caption_editor], + outputs=[upload_status, stats_box], + ) + + 
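# The single-video caption generator below reuses the same VLM settings
+            # (custom instruction and 8-bit toggle) as the batch auto-caption button.
+            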
generate_caption_btn.click( + self.auto_caption_single, + inputs=[dataset_dropdown, selected_video_name, vlm_instruction, vlm_use_8bit], + outputs=[caption_editor, upload_status], + ) + + auto_caption_btn.click( + self.auto_caption_all_uncaptioned, + inputs=[dataset_dropdown, vlm_instruction, vlm_use_8bit], + outputs=[preprocess_status, stats_box], + ) + + validate_btn.click(self.validate_dataset_ui, inputs=[dataset_dropdown], outputs=[validation_result]) + + preprocess_btn.click( + self.preprocess_dataset_ui, + inputs=[ + dataset_dropdown, + width, + height, + num_frames, + id_token, + decode_videos_check, + load_text_encoder_8bit, + ], + outputs=[preprocess_status, stats_box], + ) + + refresh_btn.click( + self.load_dataset, + inputs=[dataset_dropdown], + outputs=[ + dataset_dropdown, + dataset_gallery, + stats_box, + selected_video, + selected_video_name, + caption_editor, + ], + ) + + delete_video_btn.click( + self.delete_video_from_dataset, + inputs=[dataset_dropdown, selected_video_name], + outputs=[ + upload_status, + dataset_gallery, + selected_video, + selected_video_name, + caption_editor, + stats_box, + ], + ) + + search_box.change( + self.filter_dataset_gallery, inputs=[dataset_dropdown, search_box], outputs=[dataset_gallery] + ) + + return blocks + + +def main() -> None: + """Main entry point.""" + ui = GradioUI() + demo = ui.create_ui() + demo.queue() + demo.launch(server_name="0.0.0.0", server_port=7860) + + +if __name__ == "__main__": + main() diff --git a/scripts/dataset_manager.py b/scripts/dataset_manager.py new file mode 100644 index 0000000..da45d17 --- /dev/null +++ b/scripts/dataset_manager.py @@ -0,0 +1,400 @@ +"""Dataset Manager for LTX-Video-Trainer. + +This module provides utilities for managing video datasets, including: +- Creating and organizing datasets +- Adding videos with automatic thumbnail generation +- Managing captions +- Dataset validation +""" + +import json +import shutil +from pathlib import Path +from typing import Optional + +import cv2 + + +class DatasetManager: + """Manager for video dataset operations.""" + + def __init__(self, datasets_root: Path = Path(__file__).parent.parent / "datasets"): + """Initialize the dataset manager. + + Args: + datasets_root: Root directory for all datasets + """ + self.datasets_root = datasets_root + self.datasets_root.mkdir(exist_ok=True) + + def list_datasets(self) -> list[str]: + """Get all dataset names. + + Returns: + List of dataset names + """ + return [d.name for d in self.datasets_root.iterdir() if d.is_dir()] + + def create_dataset(self, name: str) -> Path: + """Create new dataset directory structure. 
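+
+        The new dataset gets videos/ and thumbnails/ subdirectories plus an empty
+        dataset.json manifest that later holds the per-video {media_path, caption} entries.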
+ + Args: + name: Name of the dataset + + Returns: + Path to the created dataset directory + + Raises: + ValueError: If dataset name is invalid or already exists + """ + if not name or not name.strip(): + raise ValueError("Dataset name cannot be empty") + + safe_name = "".join(c for c in name if c.isalnum() or c in ("-", "_")).strip() + if not safe_name: + raise ValueError("Dataset name must contain alphanumeric characters") + + dataset_dir = self.datasets_root / safe_name + + if dataset_dir.exists(): + raise ValueError(f"Dataset '{safe_name}' already exists") + + (dataset_dir / "videos").mkdir(parents=True, exist_ok=True) + (dataset_dir / "thumbnails").mkdir(parents=True, exist_ok=True) + + dataset_json = dataset_dir / "dataset.json" + with open(dataset_json, "w") as f: + json.dump([], f) + + return dataset_dir + + def add_videos( + self, dataset_name: str, video_files: list[str], reference_files: Optional[list[str]] = None + ) -> dict: + """Add videos to dataset and generate thumbnails. + + Args: + dataset_name: Name of the dataset + video_files: List of video file paths to add + reference_files: Optional list of reference video files for IC-LoRA (must match video_files length) + + Returns: + Dictionary with 'added' and 'failed' lists + """ + dataset_dir = self.datasets_root / dataset_name + if not dataset_dir.exists(): + raise ValueError(f"Dataset '{dataset_name}' does not exist") + + videos_dir = dataset_dir / "videos" + thumbs_dir = dataset_dir / "thumbnails" + + if reference_files: + references_dir = dataset_dir / "references" + references_dir.mkdir(exist_ok=True) + + if len(reference_files) != len(video_files): + raise ValueError("Number of reference videos must match number of videos") + + results = {"added": [], "failed": []} + + dataset_json = dataset_dir / "dataset.json" + with open(dataset_json) as f: + items = json.load(f) + + for idx, video_file in enumerate(video_files): + try: + video_path = Path(video_file) + dest_path = videos_dir / video_path.name + + if dest_path.exists(): + results["failed"].append({"video": video_path.name, "error": "Already exists"}) + continue + + shutil.copy2(video_file, dest_path) + + thumb_path = self._generate_thumbnail(dest_path, thumbs_dir) + + entry = {"media_path": f"videos/{video_path.name}", "caption": ""} + + if reference_files and idx < len(reference_files): + ref_path = Path(reference_files[idx]) + ref_dest_path = references_dir / ref_path.name + shutil.copy2(reference_files[idx], ref_dest_path) + entry["reference_path"] = f"references/{ref_path.name}" + + items.append(entry) + + results["added"].append({"video": video_path.name, "thumbnail": thumb_path.name}) + + except Exception as e: + results["failed"].append({"video": Path(video_file).name, "error": str(e)}) + + with open(dataset_json, "w") as f: + json.dump(items, f, indent=2) + + return results + + def _generate_thumbnail(self, video_path: Path, output_dir: Path) -> Path: + """Extract first frame as thumbnail. 
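+
+        For videos longer than 10 frames the frame at the midpoint is used rather than
+        the literal first frame; the thumbnail is resized to fit 256 px and padded to a square.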
+ + Args: + video_path: Path to the video file + output_dir: Directory to save thumbnail + + Returns: + Path to the generated thumbnail + + Raises: + ValueError: If video cannot be read + """ + cap = cv2.VideoCapture(str(video_path)) + + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + if frame_count > 10: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count // 2) + + ret, frame = cap.read() + cap.release() + + if not ret: + raise ValueError(f"Could not read video: {video_path}") + + thumb_path = output_dir / f"{video_path.stem}.jpg" + + height, width = frame.shape[:2] + target_size = 256 + + if width > height: + new_width = target_size + new_height = int(height * (target_size / width)) + else: + new_height = target_size + new_width = int(width * (target_size / height)) + + frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA) + + if new_width != new_height: + max_dim = max(new_width, new_height) + square_frame = cv2.copyMakeBorder( + frame, + (max_dim - new_height) // 2, + (max_dim - new_height + 1) // 2, + (max_dim - new_width) // 2, + (max_dim - new_width + 1) // 2, + cv2.BORDER_CONSTANT, + value=(0, 0, 0), + ) + frame = square_frame + + cv2.imwrite(str(thumb_path), frame) + return thumb_path + + def get_dataset_items(self, dataset_name: str) -> list[dict]: + """Get all videos and captions in dataset. + + Args: + dataset_name: Name of the dataset + + Returns: + List of dictionaries with 'media_path', 'caption', and 'thumbnail' + """ + dataset_dir = self.datasets_root / dataset_name + if not dataset_dir.exists(): + return [] + + dataset_json = dataset_dir / "dataset.json" + + if dataset_json.exists(): + with open(dataset_json) as f: + items = json.load(f) + else: + items = [] + + for item in items: + video_name = Path(item["media_path"]).name + thumb_path = dataset_dir / "thumbnails" / f"{Path(video_name).stem}.jpg" + item["thumbnail"] = str(thumb_path) if thumb_path.exists() else None + item["full_video_path"] = str(dataset_dir / item["media_path"]) + + return items + + def update_caption(self, dataset_name: str, video_name: str, caption: str) -> None: + """Update caption for a specific video. + + Args: + dataset_name: Name of the dataset + video_name: Name of the video file + caption: New caption text + """ + dataset_dir = self.datasets_root / dataset_name + dataset_json = dataset_dir / "dataset.json" + + with open(dataset_json) as f: + items = json.load(f) + + found = False + for item in items: + if Path(item["media_path"]).name == video_name: + item["caption"] = caption + found = True + break + + if not found: + items.append({"media_path": f"videos/{video_name}", "caption": caption}) + + with open(dataset_json, "w") as f: + json.dump(items, f, indent=2) + + def delete_video(self, dataset_name: str, video_name: str) -> None: + """Remove video from dataset. + + Args: + dataset_name: Name of the dataset + video_name: Name of the video file to delete + """ + dataset_dir = self.datasets_root / dataset_name + + dataset_json = dataset_dir / "dataset.json" + with open(dataset_json) as f: + items = json.load(f) + + items = [i for i in items if Path(i["media_path"]).name != video_name] + + with open(dataset_json, "w") as f: + json.dump(items, f, indent=2) + + (dataset_dir / "videos" / video_name).unlink(missing_ok=True) + (dataset_dir / "thumbnails" / f"{Path(video_name).stem}.jpg").unlink(missing_ok=True) + + def get_dataset_stats(self, dataset_name: str) -> dict: + """Get statistics about dataset. 
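+
+        Caption counts are derived from dataset.json, while the preprocessed and
+        reference-latent flags simply check for the .precomputed directory and its contents.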
+ + Args: + dataset_name: Name of the dataset + + Returns: + Dictionary with dataset statistics + """ + dataset_dir = self.datasets_root / dataset_name + items = self.get_dataset_items(dataset_name) + + has_references = any(i.get("reference_path") for i in items) + precomputed_dir = dataset_dir / ".precomputed" + has_reference_latents = (precomputed_dir / "reference_latents").exists() if precomputed_dir.exists() else False + + return { + "name": dataset_name, + "total_videos": len(items), + "captioned": sum(1 for i in items if i.get("caption") and i.get("caption").strip()), + "uncaptioned": sum(1 for i in items if not i.get("caption") or not i.get("caption").strip()), + "preprocessed": precomputed_dir.exists(), + "has_references": has_references, + "reference_latents_computed": has_reference_latents, + } + + def validate_dataset(self, dataset_name: str) -> dict: + """Validate dataset and return issues. + + Args: + dataset_name: Name of the dataset + + Returns: + Dictionary with validation results + """ + issues = [] + warnings = [] + + items = self.get_dataset_items(dataset_name) + dataset_dir = self.datasets_root / dataset_name + + if not items: + issues.append("Dataset is empty") + return {"valid": False, "issues": issues, "warnings": warnings, "total_videos": 0} + + uncaptioned = [i for i in items if not i.get("caption") or not i.get("caption").strip()] + if uncaptioned: + warnings.append(f"{len(uncaptioned)} videos without captions") + + for item in items: + video_path = dataset_dir / item["media_path"] + if not video_path.exists(): + issues.append(f"Missing video: {item['media_path']}") + + if not (dataset_dir / ".precomputed").exists(): + warnings.append("Dataset not preprocessed - will be slower to train") + + for item in items[:5]: + video_path = dataset_dir / item["media_path"] + if video_path.exists(): + cap = cv2.VideoCapture(str(video_path)) + if not cap.isOpened(): + issues.append(f"Cannot read video: {video_path.name}") + cap.release() + + return { + "valid": len(issues) == 0, + "issues": issues, + "warnings": warnings, + "total_videos": len(items), + } + + def split_video_scenes( + self, + video_path: Path, + output_dir: Path, + min_scene_length: Optional[int] = None, + threshold: Optional[float] = None, + detector_type: str = "content", + max_scenes: Optional[int] = None, + filter_shorter_than: Optional[str] = None, + ) -> list[Path]: + """Split a video into scenes using advanced scene detection. 
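+
+        Splitting is delegated to scripts.split_scenes.detect_and_split_scenes; the
+        resulting <stem>-Scene-*.mp4 files are returned in sorted order.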
+ + Args: + video_path: Path to the video file + output_dir: Directory to save split scenes + min_scene_length: Minimum scene length in frames + threshold: Detection threshold + detector_type: Type of detector ('content', 'adaptive', 'threshold', 'histogram') + max_scenes: Maximum number of scenes to detect + filter_shorter_than: Filter scenes shorter than duration (e.g., "2s", "30") + + Returns: + List of paths to split scene files + """ + from scripts.split_scenes import DetectorType, detect_and_split_scenes + + output_dir.mkdir(parents=True, exist_ok=True) + + detector_map = { + "content": DetectorType.CONTENT, + "adaptive": DetectorType.ADAPTIVE, + "threshold": DetectorType.THRESHOLD, + "histogram": DetectorType.HISTOGRAM, + } + + detector = detector_map.get(detector_type.lower(), DetectorType.CONTENT) + + detect_and_split_scenes( + video_path=str(video_path), + output_dir=output_dir, + detector_type=detector, + threshold=threshold, + min_scene_len=min_scene_length, + max_scenes=max_scenes, + filter_shorter_than=filter_shorter_than, + save_images_per_scene=0, + ) + + scene_files = sorted(output_dir.glob(f"{video_path.stem}-Scene-*.mp4")) + return scene_files + + def delete_dataset(self, dataset_name: str) -> None: + """Delete entire dataset. + + Args: + dataset_name: Name of the dataset to delete + """ + dataset_dir = self.datasets_root / dataset_name + if dataset_dir.exists(): + shutil.rmtree(dataset_dir) diff --git a/scripts/jobs/__init__.py b/scripts/jobs/__init__.py new file mode 100644 index 0000000..fdb67a1 --- /dev/null +++ b/scripts/jobs/__init__.py @@ -0,0 +1,6 @@ +"""Job management system for training jobs.""" + +from scripts.jobs.database import JobDatabase, JobStatus +from scripts.jobs.worker import QueueWorker + +__all__ = ["JobDatabase", "JobStatus", "QueueWorker"] diff --git a/scripts/jobs/database.py b/scripts/jobs/database.py new file mode 100644 index 0000000..1deb9b3 --- /dev/null +++ b/scripts/jobs/database.py @@ -0,0 +1,314 @@ +"""SQLite database for persistent job queue management.""" + +import json +import sqlite3 +import threading +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any, Optional + + +class JobStatus(str, Enum): + """Job status enumeration.""" + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +class JobDatabase: + """SQLite database for managing training job queue.""" + + def __init__(self, db_path: Path): + """Initialize the job database. 
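+
+        Connections are opened lazily per thread via threading.local, so a single
+        instance can be used safely from multiple threads.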
+ + Args: + db_path: Path to the SQLite database file + """ + self.db_path = db_path + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._local = threading.local() + self._init_db() + + def _get_connection(self) -> sqlite3.Connection: + """Get a thread-local database connection.""" + if not hasattr(self._local, "conn") or self._local.conn is None: + self._local.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._local.conn.row_factory = sqlite3.Row + return self._local.conn + + def _init_db(self) -> None: + """Initialize database schema.""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE IF NOT EXISTS jobs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + status TEXT NOT NULL, + dataset_name TEXT NOT NULL, + params TEXT NOT NULL, + progress TEXT, + error_message TEXT, + logs TEXT, + created_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT, + process_id INTEGER, + output_dir TEXT, + checkpoint_path TEXT, + validation_sample TEXT + ) + """ + ) + + cursor.execute( + """ + CREATE INDEX IF NOT EXISTS idx_status ON jobs(status) + """ + ) + + conn.commit() + + def create_job( + self, + dataset_name: str, + params: dict[str, Any], + ) -> int: + """Create a new training job. + + Args: + dataset_name: Name of the dataset to train on + params: Training parameters dictionary + + Returns: + Job ID + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + INSERT INTO jobs ( + status, dataset_name, params, created_at + ) VALUES (?, ?, ?, ?) + """, + (JobStatus.PENDING, dataset_name, json.dumps(params), datetime.now(timezone.utc).isoformat()), + ) + + conn.commit() + return cursor.lastrowid + + def get_job(self, job_id: int) -> Optional[dict[str, Any]]: + """Get job by ID. + + Args: + job_id: Job ID + + Returns: + Job dictionary or None if not found + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)) + row = cursor.fetchone() + + if row is None: + return None + + return self._row_to_dict(row) + + def get_all_jobs(self, status: Optional[JobStatus] = None) -> list[dict[str, Any]]: + """Get all jobs, optionally filtered by status. + + Args: + status: Filter by status (optional) + + Returns: + List of job dictionaries + """ + conn = self._get_connection() + cursor = conn.cursor() + + if status: + cursor.execute("SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC", (status,)) + else: + cursor.execute("SELECT * FROM jobs ORDER BY created_at DESC") + + rows = cursor.fetchall() + return [self._row_to_dict(row) for row in rows] + + def get_next_pending_job(self) -> Optional[dict[str, Any]]: + """Get the next pending job (oldest first). + + Returns: + Job dictionary or None if no pending jobs + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM jobs + WHERE status = ? + ORDER BY created_at ASC + LIMIT 1 + """, + (JobStatus.PENDING,), + ) + + row = cursor.fetchone() + if row is None: + return None + + return self._row_to_dict(row) + + def update_job_status( + self, + job_id: int, + status: JobStatus, + progress: Optional[str] = None, + error_message: Optional[str] = None, + logs: Optional[str] = None, + process_id: Optional[int] = None, + output_dir: Optional[str] = None, + checkpoint_path: Optional[str] = None, + validation_sample: Optional[str] = None, + ) -> None: + """Update job status and related fields. 
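+
+        Only fields passed as non-None are written; started_at and completed_at
+        timestamps are filled in automatically based on the status transition.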
+ + Args: + job_id: Job ID + status: New status + progress: Progress message (optional) + error_message: Error message if failed (optional) + logs: Training logs (optional) + process_id: Process ID if running (optional) + output_dir: Output directory path (optional) + checkpoint_path: Path to final checkpoint (optional) + validation_sample: Path to latest validation sample (optional) + """ + conn = self._get_connection() + cursor = conn.cursor() + + updates = ["status = ?"] + values = [status] + + if progress is not None: + updates.append("progress = ?") + values.append(progress) + + if error_message is not None: + updates.append("error_message = ?") + values.append(error_message) + + if logs is not None: + updates.append("logs = ?") + values.append(logs) + + if process_id is not None: + updates.append("process_id = ?") + values.append(process_id) + + if output_dir is not None: + updates.append("output_dir = ?") + values.append(output_dir) + + if checkpoint_path is not None: + updates.append("checkpoint_path = ?") + values.append(checkpoint_path) + + if validation_sample is not None: + updates.append("validation_sample = ?") + values.append(validation_sample) + + if status == JobStatus.RUNNING and not self.get_job(job_id).get("started_at"): + updates.append("started_at = ?") + values.append(datetime.now(timezone.utc).isoformat()) + elif status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED): + updates.append("completed_at = ?") + values.append(datetime.now(timezone.utc).isoformat()) + + values.append(job_id) + + cursor.execute( + f""" + UPDATE jobs + SET {", ".join(updates)} + WHERE id = ? + """, + values, + ) + + conn.commit() + + def delete_job(self, job_id: int) -> None: + """Delete a job from the database. + + Args: + job_id: Job ID + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("DELETE FROM jobs WHERE id = ?", (job_id,)) + conn.commit() + + def get_job_count(self, status: Optional[JobStatus] = None) -> int: + """Get count of jobs, optionally filtered by status. + + Args: + status: Filter by status (optional) + + Returns: + Job count + """ + conn = self._get_connection() + cursor = conn.cursor() + + if status: + cursor.execute("SELECT COUNT(*) FROM jobs WHERE status = ?", (status,)) + else: + cursor.execute("SELECT COUNT(*) FROM jobs") + + return cursor.fetchone()[0] + + def clear_completed_jobs(self) -> None: + """Clear all completed, failed, and cancelled jobs.""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + DELETE FROM jobs + WHERE status IN (?, ?, ?) + """, + (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED), + ) + + conn.commit() + + def _row_to_dict(self, row: sqlite3.Row) -> dict[str, Any]: + """Convert SQLite row to dictionary. + + Args: + row: SQLite row + + Returns: + Dictionary representation + """ + data = dict(row) + if data.get("params"): + data["params"] = json.loads(data["params"]) + return data + + def close(self) -> None: + """Close database connection.""" + if hasattr(self._local, "conn") and self._local.conn is not None: + self._local.conn.close() + self._local.conn = None diff --git a/scripts/jobs/run_worker.py b/scripts/jobs/run_worker.py new file mode 100755 index 0000000..ac216f9 --- /dev/null +++ b/scripts/jobs/run_worker.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +""" +Start the queue worker to process training jobs. + +This script starts a background worker that continuously polls the job database +and executes pending training jobs. 
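+
+It is intended to run as a long-lived background process; scripts/start_gradio_with_worker.py
+launches it together with the Gradio UI.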
+""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from scripts.jobs.worker import main + +if __name__ == "__main__": + main() diff --git a/scripts/jobs/worker.py b/scripts/jobs/worker.py new file mode 100644 index 0000000..47d3511 --- /dev/null +++ b/scripts/jobs/worker.py @@ -0,0 +1,299 @@ +"""Queue worker that processes training jobs.""" + +import logging +import os +import signal +import subprocess +import sys +import time +from pathlib import Path +from typing import Optional + +from scripts.jobs.database import JobDatabase, JobStatus + +logger = logging.getLogger(__name__) + + +class QueueWorker: + """Worker that processes training jobs from the queue.""" + + def __init__(self, db_path: Path, check_interval: int = 5): + """Initialize the queue worker. + + Args: + db_path: Path to the job database + check_interval: Seconds between checking for new jobs + """ + self.db = JobDatabase(db_path) + self.check_interval = check_interval + self.running = False + self.current_process: Optional[subprocess.Popen] = None + self.current_job_id: Optional[int] = None + + self.shutdown_signal_file = db_path.parent / ".worker_shutdown_signal" + if self.shutdown_signal_file.exists(): + self.shutdown_signal_file.unlink() + + signal.signal(signal.SIGINT, self._handle_shutdown) + signal.signal(signal.SIGTERM, self._handle_shutdown) + + def _handle_shutdown(self, signum: int, _frame: object) -> None: + """Handle shutdown signals gracefully.""" + logger.info(f"Received signal {signum}, shutting down gracefully...") + self.stop() + + def start(self) -> None: + """Start the worker loop.""" + self.running = True + logger.info("Queue worker started") + + while self.running: + try: + if self.shutdown_signal_file.exists(): + logger.info("Shutdown signal detected, stopping worker...") + self.stop() + break + + job = self.db.get_next_pending_job() + + if job: + self._process_job(job) + else: + time.sleep(self.check_interval) + + except Exception as e: + logger.error(f"Error in worker loop: {e}", exc_info=True) + time.sleep(self.check_interval) + + logger.info("Queue worker stopped") + + if self.shutdown_signal_file.exists(): + self.shutdown_signal_file.unlink() + + def stop(self) -> None: + """Stop the worker loop.""" + self.running = False + + if self.current_process and self.current_process.poll() is None: + logger.info(f"Cancelling current job {self.current_job_id}") + self.current_process.terminate() + try: + self.current_process.wait(timeout=10) + except subprocess.TimeoutExpired: + self.current_process.kill() + + if self.current_job_id: + self.db.update_job_status(self.current_job_id, JobStatus.CANCELLED, error_message="Worker shutdown") + + def _process_job(self, job: dict) -> None: + """Process a single training job. 
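+
+        The training CLI is launched as a subprocess; its stdout is streamed into the
+        job's log field, and the job row is polled so a cancellation from the UI can
+        terminate the run mid-training.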
+ + Args: + job: Job dictionary from database + """ + job_id = job["id"] + self.current_job_id = job_id + + logger.info(f"Starting job {job_id} for dataset: {job['dataset_name']}") + + self.db.update_job_status(job_id, JobStatus.RUNNING) + + try: + cmd = self._build_training_command(job) + + self.current_process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + self.db.update_job_status(job_id, JobStatus.RUNNING, process_id=self.current_process.pid) + + log_lines = [] + for line in self.current_process.stdout: + if not self.running: + break + + current_job = self.db.get_job(job_id) + if current_job and current_job["status"] == JobStatus.CANCELLED: + logger.info(f"Job {job_id} was cancelled, terminating process") + self.current_process.terminate() + try: + self.current_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.current_process.kill() + break + + log_lines.append(line) + + if len(log_lines) > 500: + log_lines = log_lines[-500:] + + if "Step" in line or "steps" in line.lower(): + self.db.update_job_status(job_id, JobStatus.RUNNING, progress=line.strip(), logs="".join(log_lines)) + + returncode = self.current_process.wait() + + final_logs = "".join(log_lines) + + current_job = self.db.get_job(job_id) + if current_job and current_job["status"] == JobStatus.CANCELLED: + logger.info(f"Job {job_id} was cancelled") + self.db.update_job_status( + job_id, + JobStatus.CANCELLED, + logs=final_logs, + ) + elif returncode == 0: + logger.info(f"Job {job_id} completed successfully") + + output_dir = self._get_output_dir(job) + checkpoint_path = self._find_latest_checkpoint(output_dir) + + self.db.update_job_status( + job_id, + JobStatus.COMPLETED, + progress="Training completed", + logs=final_logs, + output_dir=str(output_dir) if output_dir else None, + checkpoint_path=str(checkpoint_path) if checkpoint_path else None, + ) + else: + logger.error(f"Job {job_id} failed with return code {returncode}") + self.db.update_job_status( + job_id, + JobStatus.FAILED, + error_message=f"Training process exited with code {returncode}", + logs=final_logs, + ) + + except Exception as e: + logger.error(f"Error processing job {job_id}: {e}", exc_info=True) + current_job = self.db.get_job(job_id) + if current_job and current_job["status"] != JobStatus.CANCELLED: + self.db.update_job_status(job_id, JobStatus.FAILED, error_message=str(e)) + + finally: + self.current_process = None + self.current_job_id = None + + def _build_training_command(self, job: dict) -> list[str]: + """Build the training command from job parameters. 
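+
+        The command invokes scripts/train_cli.py with the job's parameters; an HF token,
+        if present, is passed through the HF_TOKEN environment variable rather than on
+        the command line.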
+ + Args: + job: Job dictionary + + Returns: + Command as list of strings + """ + params = job["params"] + + python_exe = sys.executable + + scripts_dir = Path(__file__).parent.parent + train_script = scripts_dir / "train_cli.py" + + cmd = [ + python_exe, + str(train_script), + "--job-id", + str(job["id"]), + "--dataset", + job["dataset_name"], + "--model-version", + params["model_source"], + "--learning-rate", + str(params["learning_rate"]), + "--steps", + str(params["steps"]), + "--lora-rank", + str(params["lora_rank"]), + "--batch-size", + str(params["batch_size"]), + "--width", + str(params["width"]), + "--height", + str(params["height"]), + "--num-frames", + str(params["num_frames"]), + "--validation-prompt", + params["validation_prompt"], + "--validation-interval", + str(params["validation_interval"]), + ] + + if params.get("id_token"): + cmd.extend(["--id-token", params["id_token"]]) + + if params.get("push_to_hub"): + cmd.append("--push-to-hub") + if params.get("hf_model_id"): + cmd.extend(["--hf-model-id", params["hf_model_id"]]) + if params.get("hf_token"): + os.environ["HF_TOKEN"] = params["hf_token"] + + return cmd + + def _get_output_dir(self, job: dict) -> Optional[Path]: + """Get the output directory for a job. + + Args: + job: Job dictionary + + Returns: + Output directory path or None + """ + params = job["params"] + project_root = Path(__file__).parent.parent.parent + output_dir = project_root / "outputs" / f"lora_r{params['lora_rank']}_job{job['id']}" + + if output_dir.exists(): + return output_dir + return None + + def _find_latest_checkpoint(self, output_dir: Optional[Path]) -> Optional[Path]: + """Find the latest checkpoint in the output directory. + + Args: + output_dir: Output directory path + + Returns: + Path to latest checkpoint or None + """ + if not output_dir or not output_dir.exists(): + return None + + checkpoints_dir = output_dir / "checkpoints" + if not checkpoints_dir.exists(): + return None + + checkpoints = list(checkpoints_dir.glob("*.safetensors")) + if not checkpoints: + return None + + return max(checkpoints, key=lambda p: p.stat().st_mtime) + + +def main() -> None: + """Main entry point for the queue worker.""" + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + + project_root = Path(__file__).parent.parent.parent + db_path = project_root / "jobs.db" + + worker = QueueWorker(db_path) + + try: + worker.start() + except KeyboardInterrupt: + logger.info("Received keyboard interrupt") + worker.stop() + + +if __name__ == "__main__": + main() diff --git a/scripts/start_gradio_with_worker.py b/scripts/start_gradio_with_worker.py new file mode 100755 index 0000000..7321a26 --- /dev/null +++ b/scripts/start_gradio_with_worker.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Start both the queue worker and Gradio UI together. + +This script launches: +1. Queue worker in the background to process training jobs +2. Gradio UI for user interaction + +Both processes will be managed together, and stopping this script will stop both. 
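+
+If either process exits unexpectedly, the monitor loop logs its remaining output and the
+other process is shut down as well.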
+""" + +import logging +import signal +import subprocess +import sys +import time +from pathlib import Path + +# Configure logging for terminal output (alternative to print statements) +logging.basicConfig( + level=logging.INFO, + format="%(message)s", # Simple format for terminal output + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) + +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +class ServiceManager: + """Manage worker and UI processes.""" + + def __init__(self): + self.worker_process = None + self.ui_process = None + self.running = True + + signal.signal(signal.SIGINT, self._handle_shutdown) + signal.signal(signal.SIGTERM, self._handle_shutdown) + + def _handle_shutdown(self, _signum: int, _frame: object) -> None: + """Handle shutdown signals.""" + self.stop() + sys.exit(0) + + def start(self) -> bool: + """Start both worker and UI processes.""" + python_exe = sys.executable + scripts_dir = Path(__file__).parent + + logger.info("🚀 Starting LTX-Video Trainer with Queue System") + logger.info("=" * 60) + logger.info("\n📋 Starting job worker...") + + worker_script = scripts_dir / "jobs" / "run_worker.py" + self.worker_process = subprocess.Popen( + [python_exe, str(worker_script)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + time.sleep(2) + + if self.worker_process.poll() is not None: + logger.error("❌ Worker failed to start") + if self.worker_process.stdout: + logger.error("\n📋 Worker Error Output:") + logger.error("-" * 60) + for line in self.worker_process.stdout: + logger.error(line.rstrip()) + logger.error("-" * 60) + return False + + logger.info("✅ Queue worker started") + logger.info("\n🎨 Starting Gradio UI...") + + ui_script = scripts_dir / "app_gradio_v2.py" + self.ui_process = subprocess.Popen( + [python_exe, str(ui_script)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + time.sleep(3) + + if self.ui_process.poll() is not None: + logger.error("❌ UI failed to start") + if self.ui_process.stdout: + logger.error("\n📋 UI Error Output:") + logger.error("-" * 60) + for line in self.ui_process.stdout: + logger.error(line.rstrip()) + logger.error("-" * 60) + self.stop() + return False + + logger.info("✅ Gradio UI started") + logger.info("\n" + "=" * 60) + logger.info("🎉 All services running!") + logger.info("📍 Open your browser to: http://localhost:7860") + logger.info("💡 Press Ctrl+C to stop all services") + logger.info("=" * 60 + "\n") + + return True + + def monitor(self) -> None: + """Monitor both processes and restart if needed.""" + while self.running: + if self.worker_process and self.worker_process.poll() is not None: + logger.error("⚠️ Queue worker stopped unexpectedly") + if self.worker_process.stdout: + remaining = self.worker_process.stdout.read() + if remaining: + logger.error("\n📋 Worker Error Output:") + logger.error("-" * 60) + logger.error(remaining) + logger.error("-" * 60) + self.running = False + break + + if self.ui_process and self.ui_process.poll() is not None: + logger.error("⚠️ UI stopped unexpectedly") + if self.ui_process.stdout: + remaining = self.ui_process.stdout.read() + if remaining: + logger.error("\n📋 UI Error Output:") + logger.error("-" * 60) + logger.error(remaining) + logger.error("-" * 60) + self.running = False + break + + if self.ui_process and self.ui_process.stdout: + line = self.ui_process.stdout.readline() + if line: + pass + + if self.worker_process and self.worker_process.stdout: + line = 
self.worker_process.stdout.readline() + if line: + pass + + time.sleep(0.1) + + def stop(self) -> None: + """Stop both processes.""" + logger.info("\n🛑 Shutting down services...") + self.running = False + + if self.ui_process: + self.ui_process.terminate() + try: + self.ui_process.wait(timeout=5) + except subprocess.TimeoutExpired: + self.ui_process.kill() + + if self.worker_process: + self.worker_process.terminate() + try: + self.worker_process.wait(timeout=10) + except subprocess.TimeoutExpired: + self.worker_process.kill() + + +def main() -> None: + """Main entry point.""" + manager = ServiceManager() + + if not manager.start(): + sys.exit(1) + + try: + manager.monitor() + except KeyboardInterrupt: + pass + finally: + manager.stop() + + +if __name__ == "__main__": + main() diff --git a/scripts/train_cli.py b/scripts/train_cli.py new file mode 100755 index 0000000..82a083a --- /dev/null +++ b/scripts/train_cli.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Command-line training script that integrates with the job queue system. + +This script can be called directly with parameters or via the queue worker. +""" + +import argparse +import os +import shutil +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +import yaml +from huggingface_hub import login + +from ltxv_trainer.config import LtxvTrainerConfig +from ltxv_trainer.trainer import LtxvTrainer +from scripts.jobs.database import JobDatabase, JobStatus + + +def generate_config(args: argparse.Namespace, output_dir: Path) -> dict: + """Generate training configuration from command-line arguments. + + Args: + args: Parsed command-line arguments + output_dir: Output directory for training + + Returns: + Configuration dictionary + """ + config = { + "model": { + "model_source": args.model_version, + "training_mode": "lora", + "load_checkpoint": None, + }, + "lora": { + "rank": args.lora_rank, + "alpha": args.lora_rank, # Usually alpha = rank + "dropout": 0.0, + "target_modules": ["to_k", "to_q", "to_v", "to_out.0"], + }, + "conditioning": { + "mode": "none", + "first_frame_conditioning_p": 0.1, + }, + "optimization": { + "learning_rate": args.learning_rate, + "steps": args.steps, + "batch_size": args.batch_size, + "gradient_accumulation_steps": 1, + "max_grad_norm": 1.0, + "optimizer_type": "adamw", + "scheduler_type": "linear", + "scheduler_params": {}, + "enable_gradient_checkpointing": False, + }, + "acceleration": { + "mixed_precision_mode": "bf16", + "quantization": None, + "load_text_encoder_in_8bit": True, + "compile_with_inductor": False, + "compilation_mode": "reduce-overhead", + }, + "data": { + "preprocessed_data_root": str(args.dataset_dir), + "num_dataloader_workers": 2, + }, + "validation": { + "prompts": [args.validation_prompt], + "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted", + "images": None, + "reference_videos": None, + "video_dims": [args.width, args.height, args.num_frames], + "seed": 42, + "inference_steps": 30, + "interval": args.validation_interval, + "videos_per_prompt": 1, + "guidance_scale": 3.5, + "skip_initial_validation": False, + }, + "checkpoints": { + "interval": None, + "keep_last_n": 1, + }, + "hub": { + "push_to_hub": args.push_to_hub, + "hub_model_id": args.hf_model_id if args.push_to_hub else None, + }, + "flow_matching": { + "timestep_sampling_mode": "shifted_logit_normal", + "timestep_sampling_params": {}, + }, + "wandb": { + "enabled": False, + "project": "ltxv-trainer", + "entity": None, + "tags": [], + 
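# W&B logging is off by default for queue/CLI runs; enable it here and set a
+            # project/entity if experiment tracking is wanted.
+            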
"log_validation_videos": True, + }, + "seed": 42, + "output_dir": str(output_dir), + } + + return config + + +def setup_dataset(dataset_name: str, datasets_root: Path, training_data_dir: Path) -> None: + """Set up the training dataset by copying from managed datasets. + + Args: + dataset_name: Name of the managed dataset + datasets_root: Root directory of managed datasets + training_data_dir: Temporary training data directory + """ + managed_dataset_dir = datasets_root / dataset_name + managed_dataset_json = managed_dataset_dir / "dataset.json" + + if not managed_dataset_json.exists(): + raise ValueError(f"Dataset '{dataset_name}' not found") + + if training_data_dir.exists(): + shutil.rmtree(training_data_dir) + training_data_dir.mkdir(parents=True) + + shutil.copy2(managed_dataset_json, training_data_dir / "dataset.json") + + precomputed_dir = managed_dataset_dir / ".precomputed" + if precomputed_dir.exists(): + shutil.copytree(precomputed_dir, training_data_dir / ".precomputed") + + +def train_with_progress_callback(trainer: LtxvTrainer, db: JobDatabase, job_id: int) -> None: + """Train with progress updates to database. + + Args: + trainer: Trainer instance + db: Job database + job_id: Current job ID + """ + + def progress_callback(step: int, total_steps: int, sampled_videos: list[Path] | None = None) -> None: + """Update job progress in database.""" + progress_pct = (step / total_steps) * 100 + progress_msg = f"Step {step}/{total_steps} ({progress_pct:.1f}%)" + + validation_sample = None + if sampled_videos: + validation_sample = str(sampled_videos[0]) + + db.update_job_status(job_id, JobStatus.RUNNING, progress=progress_msg, validation_sample=validation_sample) + + trainer.train(step_callback=progress_callback) + + +def main() -> int | None: + """Main entry point for CLI training.""" + parser = argparse.ArgumentParser(description="Train LTX-Video LoRA model") + + parser.add_argument("--job-id", type=int, help="Job ID for queue tracking") + + parser.add_argument("--dataset", required=True, help="Dataset name") + + parser.add_argument("--model-version", default="0.9.1", help="Model version") + + parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate") + parser.add_argument("--steps", type=int, default=1500, help="Training steps") + parser.add_argument("--lora-rank", type=int, default=128, help="LoRA rank") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + + parser.add_argument("--width", type=int, default=768, help="Video width") + parser.add_argument("--height", type=int, default=768, help="Video height") + parser.add_argument("--num-frames", type=int, default=25, help="Number of frames") + parser.add_argument("--id-token", type=str, default="", help="LoRA ID token") + + parser.add_argument( + "--validation-prompt", + default="a professional portrait video of a person", + help="Validation prompt", + ) + parser.add_argument("--validation-interval", type=int, default=100, help="Validation interval") + + parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub") + parser.add_argument("--hf-model-id", type=str, help="HuggingFace model ID") + + args = parser.parse_args() + + project_root = Path(__file__).parent.parent + datasets_root = project_root / "datasets" + training_data_dir = ( + project_root / "training_data" / f"job_{args.job_id}" if args.job_id else project_root / "training_data" + ) + output_dir = ( + project_root / "outputs" / f"lora_r{args.lora_rank}_job{args.job_id}" + if args.job_id + 
else project_root / "outputs" / f"lora_r{args.lora_rank}" + ) + + db = None + if args.job_id: + db_path = project_root / "jobs.db" + db = JobDatabase(db_path) + + try: + if "HF_TOKEN" in os.environ: + login(token=os.environ["HF_TOKEN"]) + + args.dataset_dir = training_data_dir + setup_dataset(args.dataset, datasets_root, training_data_dir) + + config = generate_config(args, output_dir) + + output_dir.mkdir(parents=True, exist_ok=True) + config_path = output_dir / "training_config.yaml" + with open(config_path, "w") as f: + yaml.dump(config, f, indent=2) + + trainer_config = LtxvTrainerConfig(**config) + trainer = LtxvTrainer(trainer_config) + + if db and args.job_id: + train_with_progress_callback(trainer, db, args.job_id) + else: + trainer.train() + + return 0 + + except Exception as e: + if db and args.job_id: + db.update_job_status(args.job_id, JobStatus.FAILED, error_message=str(e)) + raise + + finally: + if db: + db.close() + + +if __name__ == "__main__": + sys.exit(main())
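
For reference, the queue added in this change can also be driven without the UI. Below is a minimal sketch of enqueueing a job directly against `JobDatabase`, assuming it is run from the project root so `jobs.db` matches the file the worker polls; the dataset name and parameter values are illustrative (the numeric defaults mirror `train_cli.py`), and the keys are the ones `QueueWorker._build_training_command` reads:

```python
from pathlib import Path

from scripts.jobs.database import JobDatabase

# Open the same database file the worker polls (run this from the project root).
db = JobDatabase(Path("jobs.db"))

# Keys mirror what QueueWorker._build_training_command expects in job["params"].
job_id = db.create_job(
    dataset_name="my_dataset",  # placeholder: a dataset created via the Datasets tab
    params={
        "model_source": "0.9.1",  # same default as train_cli.py's --model-version
        "learning_rate": 2e-4,
        "steps": 1500,
        "lora_rank": 128,
        "batch_size": 1,
        "width": 768,
        "height": 768,
        "num_frames": 25,
        "validation_prompt": "a professional portrait video of a person",
        "validation_interval": 100,
        "push_to_hub": False,
    },
)
print(f"Queued job {job_id}")
db.close()
```

A worker started with `python scripts/jobs/run_worker.py` (or via `scripts/start_gradio_with_worker.py`) will pick the pending job up on its next poll.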