diff --git a/.gitignore b/.gitignore
index 19647a0..ff9a13d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,4 +26,7 @@ configs/*.yaml
!configs/ltxv_2b_full.yaml
!configs/ltxv_2b_lora.yaml
!configs/ltxv_2b_lora_low_vram.yaml
-outputs
+outputs/
+datasets/
+training_data/
+*.db
diff --git a/scripts/app_gradio_v2.py b/scripts/app_gradio_v2.py
new file mode 100644
index 0000000..40690ac
--- /dev/null
+++ b/scripts/app_gradio_v2.py
@@ -0,0 +1,2221 @@
+"""Gradio interface for LTX Video Trainer."""
+
+import datetime
+import json
+import logging
+import os
+import shutil
+from dataclasses import dataclass
+from datetime import timezone
+from pathlib import Path
+
+import gradio as gr
+import torch
+import yaml
+from huggingface_hub import login
+
+from ltxv_trainer.captioning import (
+ DEFAULT_VLM_CAPTION_INSTRUCTION,
+ CaptionerType,
+ create_captioner,
+)
+from ltxv_trainer.hf_hub_utils import convert_video_to_gif
+from ltxv_trainer.model_loader import (
+ LtxvModelVersion,
+)
+from ltxv_trainer.trainer import LtxvTrainer, LtxvTrainerConfig
+from scripts.dataset_manager import DatasetManager
+from scripts.jobs.database import JobDatabase, JobStatus
+from scripts.preprocess_dataset import preprocess_dataset
+from scripts.process_videos import parse_resolution_buckets
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
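+# Prevent the CUDA caching allocator from splitting blocks larger than 512 MB,
+# which helps avoid fragmentation-related OOMs during long training runs.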
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+
+torch.cuda.empty_cache()
+BASE_DIR = Path(__file__).parent
+PROJECT_ROOT = BASE_DIR.parent
+OUTPUTS_DIR = PROJECT_ROOT / "outputs"
+TRAINING_DATA_DIR = PROJECT_ROOT / "training_data"
+DATASETS_DIR = PROJECT_ROOT / "datasets"
+VALIDATION_SAMPLES_DIR = OUTPUTS_DIR / "validation_samples"
+
+OUTPUTS_DIR.mkdir(exist_ok=True)
+TRAINING_DATA_DIR.mkdir(exist_ok=True)
+VALIDATION_SAMPLES_DIR.mkdir(exist_ok=True)
+DATASETS_DIR.mkdir(exist_ok=True)
+
+
+@dataclass
+class TrainingConfigParams:
+ """Parameters for generating training configuration."""
+
+ model_source: str
+ learning_rate: float
+ steps: int
+ lora_rank: int
+ batch_size: int
+ validation_prompt: str
+ video_dims: tuple[int, int, int]
+ validation_interval: int = 100
+ push_to_hub: bool = False
+ hub_model_id: str | None = None
+
+
+@dataclass
+class TrainingState:
+ """State for tracking training progress."""
+
+ status: str | None = None
+ progress: str | None = None
+ validation: str | None = None
+ download: str | None = None
+ error: str | None = None
+ hf_repo: str | None = None
+ checkpoint_path: str | None = None
+
+ def reset(self) -> None:
+ """Reset state to initial values."""
+ self.status = "running"
+ self.progress = None
+ self.validation = None
+ self.download = None
+ self.error = None
+ self.hf_repo = None
+ self.checkpoint_path = None
+
+ def update(self, **kwargs: str | int | None) -> None:
+ """Update state with provided values."""
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+
+
+@dataclass
+class TrainingParams:
+ dataset_name: str
+ validation_prompt: str
+ learning_rate: float
+ steps: int
+ lora_rank: int
+ batch_size: int
+ model_source: str
+ width: int
+ height: int
+ num_frames: int
+ push_to_hub: bool
+ hf_model_id: str
+ hf_token: str | None = None
+ id_token: str | None = None
+ validation_interval: int = 100
+
+
+def _handle_validation_sample(step: int, video_path: Path) -> str | None:
+ """Handle validation sample conversion and storage.
+
+ Args:
+ step: Current training step
+ video_path: Path to the validation video
+
+ Returns:
+ Path to the GIF file if successful, None otherwise
+ """
+ gif_path = VALIDATION_SAMPLES_DIR / f"sample_step_{step}.gif"
+ try:
+ convert_video_to_gif(video_path, gif_path)
+ logger.info(f"New validation sample converted to GIF at step {step}: {gif_path}")
+ return str(gif_path)
+ except Exception as e:
+ logger.error(f"Failed to convert validation video to GIF: {e}")
+ return None
+
+
+def generate_training_config(params: TrainingConfigParams, training_data_dir: str) -> dict:
+ """Generate training configuration from parameters.
+
+ Args:
+ params: Training configuration parameters
+ training_data_dir: Directory containing training data
+
+ Returns:
+ Dictionary containing the complete training configuration
+ """
+ template_path = Path(__file__).parent.parent / "configs" / "ltxv_13b_lora_template.yaml"
+ with open(template_path) as f:
+ config = yaml.safe_load(f)
+
+ config["model"]["model_source"] = params.model_source
+ config["lora"]["rank"] = params.lora_rank
+ config["lora"]["alpha"] = params.lora_rank
+ config["optimization"]["learning_rate"] = params.learning_rate
+ config["optimization"]["steps"] = params.steps
+ config["optimization"]["batch_size"] = params.batch_size
+ config["data"]["preprocessed_data_root"] = str(training_data_dir)
+ config["output_dir"] = str(OUTPUTS_DIR / f"lora_r{params.lora_rank}")
+
+ config["hub"] = {
+ "push_to_hub": params.push_to_hub,
+ "hub_model_id": params.hub_model_id if params.push_to_hub else None,
+ }
+
+ width, height, num_frames = params.video_dims
+ config["validation"] = {
+ "prompts": [params.validation_prompt],
+ "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
+ "video_dims": [width, height, num_frames],
+ "seed": 42,
+ "inference_steps": 30,
+ "interval": params.validation_interval,
+ "videos_per_prompt": 1,
+ "guidance_scale": 3.5,
+ }
+
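+    # The trainer's validation config may expect an "images" key even when no
+    # conditioning images are provided, so make sure it exists (as None).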
+ if "validation" in config and "images" not in config["validation"]:
+ config["validation"]["images"] = None
+
+ return config
+
+
+class GradioUI:
+ """Class to manage Gradio UI components and state."""
+
+ def __init__(self) -> None:
+ self.training_state = TrainingState()
+ self.dataset_manager = DatasetManager(datasets_root=DATASETS_DIR)
+ self.current_dataset = None
+
+ db_path = PROJECT_ROOT / "jobs.db"
+ self.job_db = JobDatabase(db_path)
+ self.current_job_id = None
+
+ self.validation_prompt = None
+ self.status_output = None
+ self.progress_output = None
+ self.download_btn = None
+ self.hf_repo_link = None
+
+ def reset_interface(self) -> dict:
+ """Reset the interface and clean up all training data.
+
+ Returns:
+ Dictionary of Gradio component updates
+ """
+ self.training_state.reset()
+ self.current_job_id = None
+
+ if TRAINING_DATA_DIR.exists():
+ shutil.rmtree(TRAINING_DATA_DIR)
+ TRAINING_DATA_DIR.mkdir(exist_ok=True)
+
+ if OUTPUTS_DIR.exists():
+ shutil.rmtree(OUTPUTS_DIR)
+ OUTPUTS_DIR.mkdir(exist_ok=True)
+ VALIDATION_SAMPLES_DIR.mkdir(exist_ok=True)
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ return {
+ self.validation_prompt: gr.update(
+ value="a professional portrait video of a person with blurry bokeh background",
+ info="Include the LoRA ID token (e.g., <lora>) in this prompt if desired.",
+ ),
+ self.status_output: gr.update(value=""),
+ self.progress_output: gr.update(value=""),
+ self.download_btn: gr.update(visible=False),
+ self.hf_repo_link: gr.update(visible=False, value=""),
+ }
+
+ def get_model_path(self) -> str | None:
+ """Get the path to the trained model file."""
+ if self.training_state.download and Path(self.training_state.download).exists():
+ return self.training_state.download
+ return None
+
+ def update_progress(self) -> tuple[str, str, gr.update, str, gr.update]:
+ """Update the UI with current training progress."""
+ # Default return values
+ job_display = ""
+ progress_msg = ""
+ file_update = gr.update(visible=False)
+ hf_html = ""
+ hf_update = gr.update(visible=False)
+
+ # Handle current job from queue
+ if self.current_job_id:
+ job = self.job_db.get_job(self.current_job_id)
+ if job:
+ job_display = f"**Current Job:** #{self.current_job_id} | Dataset: `{job['dataset_name']}`"
+ status = job["status"]
+ progress = job.get("progress", "")
+
+ status_messages = {
+ JobStatus.PENDING: "Waiting in queue...",
+ JobStatus.RUNNING: progress if progress else "Training in progress...",
+ JobStatus.COMPLETED: "Training completed!",
+ JobStatus.FAILED: "Training failed",
+ JobStatus.CANCELLED: "Job cancelled",
+ }
+ progress_msg = status_messages.get(status, "")
+
+ # Handle completed job artifacts
+ if status == JobStatus.COMPLETED:
+ hf_repo = job.get("hf_repo_url", "")
+ checkpoint_path = job.get("checkpoint_path", "")
+
+ if hf_repo:
+                        hf_html = f'<a href="{hf_repo}" target="_blank">View model on HuggingFace Hub</a>'
+ hf_update = gr.update(visible=True)
+ elif checkpoint_path and Path(checkpoint_path).exists():
+ file_update = gr.update(
+ value=checkpoint_path, visible=True, label=f"Download {Path(checkpoint_path).name}"
+ )
+
+ return (job_display, progress_msg, file_update, hf_html, hf_update)
+
+ # Handle legacy direct training mode
+ if self.training_state.status is not None:
+ job_display = "**Direct Training Mode** (Legacy)"
+ progress_msg = self.training_state.progress
+ status = self.training_state.status
+
+ if status == "complete":
+ if self.training_state.hf_repo:
+ hf_html = (
+                        f'<a href="{self.training_state.hf_repo}" target="_blank">'
+                        "View model on HuggingFace Hub</a>"
+ )
+ hf_update = gr.update(visible=True)
+ elif self.training_state.download and Path(self.training_state.download).exists():
+ file_update = gr.update(
+ value=self.training_state.download,
+ visible=True,
+ label=f"Download {Path(self.training_state.download).name}",
+ )
+
+ return (job_display, progress_msg, file_update, hf_html, hf_update)
+
+ # No active training
+ return (job_display, progress_msg, file_update, hf_html, hf_update)
+
+ def _save_checkpoint(self, saved_path: Path, trainer_config: LtxvTrainerConfig) -> tuple[Path, str | None]:
+ """Save and copy the checkpoint to a permanent location.
+
+ Args:
+ saved_path: Path where the checkpoint was initially saved
+ trainer_config: Training configuration
+
+ Returns:
+ Tuple of (permanent checkpoint path, HF repo URL if applicable)
+ """
+ permanent_checkpoint_dir = OUTPUTS_DIR / "checkpoints"
+ permanent_checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+ timestamp = datetime.datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S")
+ checkpoint_filename = f"comfy_lora_checkpoint_{timestamp}.safetensors"
+ permanent_checkpoint_path = permanent_checkpoint_dir / checkpoint_filename
+
+ try:
+ shutil.copy2(saved_path, permanent_checkpoint_path)
+ logger.info(f"Checkpoint copied to permanent location: {permanent_checkpoint_path}")
+ except Exception as e:
+ logger.error(f"Failed to copy checkpoint: {e}")
+ permanent_checkpoint_path = saved_path
+
+ hf_repo = (
+ f"https://huggingface.co/{trainer_config.hub.hub_model_id}" if trainer_config.hub.hub_model_id else None
+ )
+
+ return permanent_checkpoint_path, hf_repo
+
+ def _preprocess_dataset(
+ self,
+ dataset_file: Path,
+ model_source: str,
+ width: int,
+ height: int,
+ num_frames: int,
+ id_token: str | None = None,
+ ) -> tuple[bool, str | None]:
+ """Preprocess the dataset by computing video latents and text embeddings.
+
+ Args:
+ dataset_file: Path to the dataset.json file
+ model_source: Model source identifier
+ width: Video width
+ height: Video height
+ num_frames: Number of frames
+ id_token: Optional token to prepend to captions (for LoRA training)
+
+ Returns:
+ Tuple of (success, error_message)
+ """
+ try:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ precomputed_dir = TRAINING_DATA_DIR / ".precomputed"
+ if precomputed_dir.exists():
+ shutil.rmtree(precomputed_dir)
+
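+            # Buckets are expressed as a single "WxHxF" string (e.g. "768x768x25") and
+            # parsed into the bucket tuples that preprocess_dataset expects.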
+ resolution_buckets = f"{width}x{height}x{num_frames}"
+ parsed_buckets = parse_resolution_buckets(resolution_buckets)
+
+ preprocess_dataset(
+ dataset_file=str(dataset_file),
+ caption_column="caption",
+ video_column="media_path",
+ resolution_buckets=parsed_buckets,
+ batch_size=1,
+ output_dir=None,
+ id_token=id_token,
+ vae_tiling=False,
+ decode_videos=True,
+ model_source=model_source,
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ load_text_encoder_in_8bit=False,
+ )
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return True, None
+
+ except Exception as e:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ return False, f"Error preprocessing dataset: {e!s}"
+
+ def _should_preprocess_data(
+ self,
+ width: int,
+ height: int,
+ num_frames: int,
+ videos: list[str],
+    ) -> tuple[bool, bool]:
+ """Check if data needs to be preprocessed based on resolution changes.
+
+ Args:
+ width: Video width
+ height: Video height
+ num_frames: Number of frames
+ videos: List of video file paths
+
+ Returns:
+            Tuple of (needs_preprocessing, needs_to_copy)
+ """
+ resolution_file = TRAINING_DATA_DIR / ".resolution_config"
+ current_resolution = f"{width}x{height}x{num_frames}"
+ needs_to_copy = False
+ for video in videos:
+ if Path(video).exists():
+ needs_to_copy = True
+ if needs_to_copy:
+ logger.info("Videos provided, will copy them to training directory.")
+ return True, needs_to_copy
+
+ if not resolution_file.exists() or not (TRAINING_DATA_DIR / "captions.json").exists():
+ return True, needs_to_copy
+
+ try:
+ with open(resolution_file) as f:
+ previous_resolution = f.read().strip()
+ return previous_resolution != current_resolution, needs_to_copy
+ except Exception:
+ return True, needs_to_copy
+
+ def _save_resolution_config(
+ self,
+ width: int,
+ height: int,
+ num_frames: int,
+ ) -> None:
+ """Save current resolution configuration.
+
+ Args:
+ width: Video width
+ height: Video height
+ num_frames: Number of frames
+ """
+ resolution_file = TRAINING_DATA_DIR / ".resolution_config"
+ current_resolution = f"{width}x{height}x{num_frames}"
+
+ with open(resolution_file, "w") as f:
+ f.write(current_resolution)
+
+ def _sync_captions_from_ui(
+ self, params: TrainingParams, training_captions_file: Path
+ ) -> tuple[dict[str, str] | None, str | None]:
+ """Sync captions from the UI to captions.json. Returns (captions_data, error_message)."""
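+        # NOTE: this legacy path expects `params` to carry a `captions_json` attribute,
+        # which is not part of the TrainingParams dataclass defined above.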
+ if params.captions_json:
+ try:
+ dataset = json.loads(params.captions_json)
+ # Convert list of dicts to captions_data dict
+ captions_data = {item["media_path"]: item["caption"] for item in dataset}
+ # Save to captions.json (overwrite every time)
+ with open(training_captions_file, "w") as f:
+ json.dump(captions_data, f, indent=2)
+ return captions_data, None
+ except Exception as e:
+ return None, f"Invalid captions JSON: {e!s}"
+ else:
+ return None, "No captions found in the UI. Please process videos first."
+
+ # ruff: noqa: PLR0912
+ def start_training(
+ self,
+ params: TrainingParams,
+ ) -> tuple[str, gr.update]:
+ """Queue a training job."""
+ # Validate dataset exists
+ if not params.dataset_name:
+ return "Please select a dataset from the Datasets tab", gr.update(interactive=True)
+
+ managed_dataset_dir = self.dataset_manager.datasets_root / params.dataset_name
+ managed_dataset_json = managed_dataset_dir / "dataset.json"
+
+ if not managed_dataset_json.exists():
+ return f"Dataset '{params.dataset_name}' not found or has no videos", gr.update(interactive=True)
+
+ # Check if dataset is preprocessed
+ precomputed_dir = managed_dataset_dir / ".precomputed"
+ if not precomputed_dir.exists():
+ return (
+ f"Dataset '{params.dataset_name}' is not preprocessed. Please preprocess it in the Datasets tab first.",
+ gr.update(interactive=True),
+ )
+
+ # Create job parameters
+ job_params = {
+ "model_source": params.model_source,
+ "learning_rate": params.learning_rate,
+ "steps": params.steps,
+ "lora_rank": params.lora_rank,
+ "batch_size": params.batch_size,
+ "width": params.width,
+ "height": params.height,
+ "num_frames": params.num_frames,
+ "id_token": params.id_token or "",
+ "validation_prompt": params.validation_prompt,
+ "validation_interval": params.validation_interval,
+ "push_to_hub": params.push_to_hub,
+ "hf_model_id": params.hf_model_id if params.push_to_hub else "",
+ "hf_token": params.hf_token if params.push_to_hub else "",
+ }
+
+ # Create job in database
+ try:
+ job_id = self.job_db.create_job(params.dataset_name, job_params)
+ self.current_job_id = job_id
+
+ return (
+ f"Job #{job_id} created and queued for training on dataset '{params.dataset_name}'.\n"
+ f"The worker will start training automatically. Check the Queue tab for status.",
+ gr.update(interactive=True),
+ )
+ except Exception as e:
+ return f"Failed to create training job: {e}", gr.update(interactive=True)
+
+ def start_training_direct(
+ self,
+ params: TrainingParams,
+ ) -> tuple[str, gr.update]:
+ """Start the training process directly (legacy method)."""
+ if params.hf_token:
+ login(token=params.hf_token)
+
+ try:
+ # Clear any existing CUDA cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ # Set training status
+ self.training_state.reset() # This sets status to "running"
+
+ # Prepare data directory
+ data_dir = TRAINING_DATA_DIR
+ data_dir.mkdir(exist_ok=True)
+
+ # Validate dataset exists
+ if not params.dataset_name:
+ return "Please select a dataset from the Datasets tab", gr.update(interactive=True)
+
+ # Use managed dataset
+ managed_dataset_dir = self.dataset_manager.datasets_root / params.dataset_name
+ managed_dataset_json = managed_dataset_dir / "dataset.json"
+
+ if not managed_dataset_json.exists():
+ return f"Dataset '{params.dataset_name}' not found or has no videos", gr.update(interactive=True)
+
+ # Copy dataset.json to training directory
+ training_captions_file = data_dir / "captions.json"
+ shutil.copy2(managed_dataset_json, training_captions_file)
+
+ # Load the dataset to get video paths
+ with open(training_captions_file) as f:
+ dataset = json.load(f)
+
+ # Copy videos from managed dataset to training directory
+ for item in dataset:
+ src_video = managed_dataset_dir / item["media_path"]
+ dest_video = data_dir / Path(item["media_path"]).name
+
+ if not dest_video.exists() and src_video.exists():
+ shutil.copy2(src_video, dest_video)
+
+ # Update the media_path to be relative to data_dir
+ item["media_path"] = Path(item["media_path"]).name
+
+ # Save updated dataset with corrected paths
+ with open(training_captions_file, "w") as f:
+ json.dump(dataset, f, indent=2)
+
+ # Check if preprocessing is needed
+ needs_preprocessing = self._should_preprocess_data(params.width, params.height, params.num_frames, [])[0]
+
+ # Preprocess if needed (first time or resolution changed)
+ if needs_preprocessing:
+ # Clean up existing precomputed data
+ precomputed_dir = TRAINING_DATA_DIR / ".precomputed"
+ if precomputed_dir.exists():
+ shutil.rmtree(precomputed_dir)
+
+ success, error_msg = self._preprocess_dataset(
+ dataset_file=training_captions_file,
+ model_source=params.model_source,
+ width=params.width,
+ height=params.height,
+ num_frames=params.num_frames,
+ id_token=params.id_token,
+ )
+ if not success:
+ return error_msg, gr.update(interactive=True)
+
+ # Save current resolution config after successful preprocessing
+ self._save_resolution_config(params.width, params.height, params.num_frames)
+
+ # Generate training config
+ config_params = TrainingConfigParams(
+ model_source=params.model_source,
+ learning_rate=params.learning_rate,
+ steps=params.steps,
+ lora_rank=params.lora_rank,
+ batch_size=params.batch_size,
+ validation_prompt=params.validation_prompt,
+ video_dims=(params.width, params.height, params.num_frames),
+ validation_interval=params.validation_interval,
+ push_to_hub=params.push_to_hub,
+ hub_model_id=params.hf_model_id if params.push_to_hub else None,
+ )
+
+ config = generate_training_config(config_params, str(data_dir))
+ config_path = OUTPUTS_DIR / "train_config.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f, indent=4)
+
+ # Run training
+ self.run_training(config_path)
+
+ return "Training completed!", gr.update(interactive=True)
+
+ except Exception as e:
+ return f"Error during training: {e!s}", gr.update(interactive=True)
+ finally:
+ # Clean up CUDA cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ def run_training(self, config_path: Path) -> None:
+ """Run the training process and update progress."""
+ # Reset training state at the start
+ self.training_state.reset()
+
+ try:
+ # Load config from YAML
+ with open(config_path) as f:
+ config_dict = yaml.safe_load(f)
+
+ # Initialize trainer config and trainer
+ trainer_config = LtxvTrainerConfig(**config_dict)
+ trainer = LtxvTrainer(trainer_config)
+
+ def training_callback(step: int, total_steps: int, sampled_videos: list[Path] | None = None) -> None:
+ """Callback function to update training progress and show samples."""
+ # Update progress
+ progress_pct = (step / total_steps) * 100
+ self.training_state.update(progress=f"Step {step}/{total_steps} ({progress_pct:.1f}%)")
+
+ # Update validation video at validation intervals
+ if step % trainer_config.validation.interval == 0 and sampled_videos:
+ # Convert the first sample to GIF
+ gif_path = _handle_validation_sample(step, sampled_videos[0])
+ if gif_path:
+ self.training_state.update(validation=gif_path)
+
+ logger.info("Starting training...")
+
+ # Start training with callback
+ saved_path, stats = trainer.train(disable_progress_bars=False, step_callback=training_callback)
+
+ # Save checkpoint and get paths
+ permanent_checkpoint_path, hf_repo = self._save_checkpoint(saved_path, trainer_config)
+
+ # Update training outputs with completion status
+ self.training_state.update(
+ status="complete",
+ download=str(permanent_checkpoint_path),
+ hf_repo=hf_repo,
+ checkpoint_path=str(permanent_checkpoint_path),
+ )
+
+ logger.info(f"Training completed. Model saved to {permanent_checkpoint_path}")
+ logger.info(f"Training stats: {stats}")
+
+ except Exception as e:
+ logger.error(f"Training failed: {e!s}", exc_info=True)
+ self.training_state.update(status="failed", error=str(e))
+ raise
+ finally:
+ # Don't reset current_job here - let the UI handle it
+ if self.training_state.status == "running":
+ self.training_state.update(status="failed")
+
+ def create_new_dataset(self, name: str) -> tuple[str, gr.Dropdown, list, dict, None, str, str, gr.Dropdown]:
+ """Create a new dataset.
+
+ Args:
+ name: Name of the dataset
+
+ Returns:
+ Tuple of (status message, updated dropdown, empty gallery, empty stats, cleared fields, training dropdown)
+ """
+ if not name or not name.strip():
+ return "Please enter a dataset name", gr.update(), gr.update(), {}, None, "", "", gr.update()
+
+ try:
+ self.dataset_manager.create_dataset(name)
+ datasets = self.dataset_manager.list_datasets()
+ # Return updated dropdown with new dataset selected, and clear all other fields
+ return (
+ f"Created dataset: {name}",
+ gr.update(choices=datasets, value=name),
+ [], # Empty gallery
+ {"name": name, "total_videos": 0, "captioned": 0, "uncaptioned": 0, "preprocessed": False}, # Stats
+ None, # Clear selected video
+ "", # Clear video name
+ "", # Clear caption editor
+ gr.update(choices=datasets), # Update training tab dropdown
+ )
+ except Exception as e:
+ return f"Error: {e}", gr.update(), gr.update(), {}, None, "", "", gr.update()
+
+ def load_dataset(self, dataset_name: str) -> tuple[str | None, list, dict, None | str, str, str]:
+ """Load a dataset and display its contents.
+
+ Args:
+ dataset_name: Name of the dataset to load
+
+ Returns:
+ Tuple of (dataset name, gallery items, statistics, cleared video/name/caption)
+ """
+ if not dataset_name:
+ return None, [], {}, None, "", ""
+
+ self.current_dataset = dataset_name
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+
+ # Prepare gallery items (thumbnails with captions)
+ gallery_items = [
+ (
+ item["thumbnail"],
+ item["caption"][:50] + "..."
+ if len(item.get("caption", "")) > 50
+ else item.get("caption", "") or "No caption",
+ )
+ for item in items
+ if item["thumbnail"]
+ ]
+
+ # Clear the video preview and caption editor when switching datasets
+ return dataset_name, gallery_items, stats, None, "", ""
+
+ def upload_videos_to_dataset(self, files: list, dataset_name: str) -> tuple[str, gr.Gallery, dict]:
+ """Upload videos to a dataset.
+
+ Args:
+ files: List of video file paths
+ dataset_name: Name of the dataset
+
+ Returns:
+ Tuple of (status message, updated gallery, updated stats)
+ """
+ if not dataset_name:
+ return "Please select a dataset first", gr.update(), gr.update()
+
+ if not files:
+ return "No files selected", gr.update(), gr.update()
+
+ try:
+ result = self.dataset_manager.add_videos(dataset_name, files)
+
+ message = f"Added {len(result['added'])} videos"
+ if result["failed"]:
+ message += f", {len(result['failed'])} failed"
+ for failed in result["failed"][:3]: # Show first 3 failures
+ message += f"\n- {failed['video']}: {failed['error']}"
+
+ # Refresh gallery and stats
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ gallery_items = [
+ (
+ i["thumbnail"],
+ (
+ i.get("caption", "")[:50] + "..."
+ if len(i.get("caption", "")) > 50
+ else i.get("caption", "") or "No caption"
+ ),
+ )
+ for i in items
+ if i["thumbnail"]
+ ]
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+
+ return message, gr.update(value=gallery_items), stats
+ except Exception as e:
+ return f"Error: {e}", gr.update(), gr.update()
+
+ def upload_videos_with_references(
+ self, video_files: list, reference_files: list, dataset_name: str
+ ) -> tuple[str, gr.Gallery, dict]:
+ """Upload videos with reference videos to a dataset for IC-LoRA training.
+
+ Args:
+ video_files: List of target video file paths
+ reference_files: List of reference video file paths
+ dataset_name: Name of the dataset
+
+ Returns:
+ Tuple of (status message, updated gallery, updated stats)
+ """
+ if not dataset_name:
+ return "Please select a dataset first", gr.update(), gr.update()
+
+ if not video_files:
+ return "No target videos selected", gr.update(), gr.update()
+
+ if not reference_files:
+ return "No reference videos selected", gr.update(), gr.update()
+
+ try:
+ result = self.dataset_manager.add_videos(dataset_name, video_files, reference_files)
+
+ message = f"Added {len(result['added'])} video pairs (target + reference)"
+ if result["failed"]:
+ message += f", {len(result['failed'])} failed"
+ for failed in result["failed"][:3]: # Show first 3 failures
+ message += f"\n- {failed['video']}: {failed['error']}"
+
+ # Refresh gallery and stats
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ gallery_items = [
+ (
+ i["thumbnail"],
+ (
+ i.get("caption", "")[:50] + "..."
+ if len(i.get("caption", "")) > 50
+ else i.get("caption", "") or "No caption"
+ ),
+ )
+ for i in items
+ if i["thumbnail"]
+ ]
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+
+ return message, gr.update(value=gallery_items), stats
+ except Exception as e:
+ return f"Error: {e}", gr.update(), gr.update()
+
+ def select_video_from_gallery(self, evt: gr.SelectData, dataset_name: str) -> tuple[str | None, str, str]:
+ """Handle video selection from gallery.
+
+ Args:
+ evt: Selection event data
+ dataset_name: Name of the dataset
+
+ Returns:
+ Tuple of (video path, video name, caption)
+ """
+ if not dataset_name:
+ return None, "", ""
+
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ if evt.index >= len(items):
+ return None, "", ""
+
+ selected_item = items[evt.index]
+ video_path = selected_item["full_video_path"]
+ video_name = Path(selected_item["media_path"]).name
+ caption = selected_item.get("caption", "")
+
+ return video_path, video_name, caption
+
+ def save_caption_edit(self, dataset_name: str, video_name: str, caption: str) -> tuple[str, dict]:
+ """Save edited caption for a video.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_name: Name of the video file
+ caption: New caption text
+
+ Returns:
+ Tuple of (status message, updated stats)
+ """
+ if not dataset_name or not video_name:
+ return "No video selected", gr.update()
+
+ try:
+ self.dataset_manager.update_caption(dataset_name, video_name, caption)
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+ return f"Caption saved for {video_name}", stats
+ except Exception as e:
+ return f"Error: {e}", gr.update()
+
+ def auto_caption_single(
+ self, dataset_name: str, video_name: str, vlm_instruction: str, use_8bit: bool
+ ) -> tuple[str, str]:
+ """Auto-generate caption for a single video.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_name: Name of the video file
+ vlm_instruction: Custom VLM instruction for captioning
+ use_8bit: Whether to use 8-bit quantization
+
+ Returns:
+ Tuple of (generated caption, status message)
+ """
+ if not dataset_name or not video_name:
+ return "", "No video selected"
+
+ try:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Use custom instruction if provided, otherwise use default
+ instruction = (
+ vlm_instruction if vlm_instruction and vlm_instruction.strip() else DEFAULT_VLM_CAPTION_INSTRUCTION
+ )
+
+ captioner = create_captioner(
+ captioner_type=CaptionerType.QWEN_25_VL,
+ use_8bit=use_8bit,
+ vlm_instruction=instruction,
+ device=device,
+ )
+
+ video_path = Path(self.dataset_manager.datasets_root) / dataset_name / "videos" / video_name
+ caption = captioner.caption(str(video_path))
+
+ self.dataset_manager.update_caption(dataset_name, video_name, caption)
+ return caption, f"Caption generated for {video_name}"
+ except Exception as e:
+ return "", f"Error: {e}"
+
+ def auto_caption_all_uncaptioned(self, dataset_name: str, vlm_instruction: str, use_8bit: bool) -> tuple[str, dict]:
+ """Auto-generate captions for all uncaptioned videos.
+
+ Args:
+ dataset_name: Name of the dataset
+ vlm_instruction: Custom VLM instruction for captioning
+ use_8bit: Whether to use 8-bit quantization
+
+ Returns:
+ Tuple of (status message, updated stats)
+ """
+ if not dataset_name:
+ return "Please select a dataset first", gr.update()
+
+ try:
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ uncaptioned = [i for i in items if not i.get("caption") or not i.get("caption").strip()]
+
+ if not uncaptioned:
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+ return "All videos already have captions", stats
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Use custom instruction if provided, otherwise use default
+ instruction = (
+ vlm_instruction if vlm_instruction and vlm_instruction.strip() else DEFAULT_VLM_CAPTION_INSTRUCTION
+ )
+
+ captioner = create_captioner(
+ captioner_type=CaptionerType.QWEN_25_VL,
+ use_8bit=use_8bit,
+ vlm_instruction=instruction,
+ device=device,
+ )
+
+            for item in uncaptioned:
+ video_path = Path(self.dataset_manager.datasets_root) / dataset_name / item["media_path"]
+ caption = captioner.caption(str(video_path))
+ video_name = Path(item["media_path"]).name
+ self.dataset_manager.update_caption(dataset_name, video_name, caption)
+
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+ return f"Generated captions for {len(uncaptioned)} videos", stats
+ except Exception as e:
+ return f"Error: {e}", gr.update()
+
+ def validate_dataset_ui(self, dataset_name: str) -> dict:
+ """Validate a dataset and return issues.
+
+ Args:
+ dataset_name: Name of the dataset
+
+ Returns:
+ Validation results dictionary
+ """
+ if not dataset_name:
+ return {"error": "Please select a dataset first"}
+
+ try:
+ return self.dataset_manager.validate_dataset(dataset_name)
+ except Exception as e:
+ return {"error": str(e)}
+
+ def preprocess_dataset_ui(
+ self,
+ dataset_name: str,
+ width: int,
+ height: int,
+ num_frames: int,
+ id_token: str,
+ decode_videos: bool,
+ load_text_encoder_8bit: bool,
+ ) -> tuple[str, dict]:
+ """Preprocess a dataset from the UI.
+
+ Args:
+ dataset_name: Name of the dataset
+ width: Video width
+ height: Video height
+ num_frames: Number of frames
+ id_token: ID token to prepend to captions
+ decode_videos: Whether to decode videos for verification
+ load_text_encoder_8bit: Whether to load text encoder in 8-bit mode
+
+ Returns:
+ Tuple of (status message, updated stats)
+ """
+ if not dataset_name:
+ return "Please select a dataset first", gr.update()
+
+ try:
+ dataset_dir = self.dataset_manager.datasets_root / dataset_name
+ dataset_json = dataset_dir / "dataset.json"
+
+ if not dataset_json.exists():
+ return "Dataset JSON not found", gr.update()
+
+ # Check if dataset has reference videos
+ with open(dataset_json) as f:
+ dataset_items = json.load(f)
+
+ has_references = any(item.get("reference_path") for item in dataset_items)
+
+ resolution_buckets = f"{width}x{height}x{num_frames}"
+ parsed_buckets = parse_resolution_buckets(resolution_buckets)
+
+ preprocess_dataset(
+ dataset_file=str(dataset_json),
+ caption_column="caption",
+ video_column="media_path",
+ resolution_buckets=parsed_buckets,
+ batch_size=1,
+ output_dir=None, # Will use default .precomputed
+ id_token=id_token if id_token and id_token.strip() else None,
+ vae_tiling=False,
+ decode_videos=decode_videos,
+ model_source=LtxvModelVersion.latest(),
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ load_text_encoder_in_8bit=load_text_encoder_8bit,
+ reference_column="reference_path" if has_references else None,
+ )
+
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+ message = "Preprocessing complete!"
+ if has_references:
+ message += " (IC-LoRA reference videos processed)"
+ if decode_videos:
+ decoded_dir = dataset_dir / ".precomputed" / "decoded_videos"
+ message += f"\nDecoded videos saved to: {decoded_dir}"
+ return message, stats
+ except Exception as e:
+ logger.error(f"Preprocessing failed: {e}", exc_info=True)
+ return f"Preprocessing failed: {e}", gr.update()
+
+ def delete_video_from_dataset(
+ self, dataset_name: str, video_name: str
+ ) -> tuple[str, gr.Gallery, None, str, str, dict]:
+ """Delete a video from a dataset.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_name: Name of the video to delete
+
+ Returns:
+ Tuple of (status message, updated gallery, cleared video/name/caption, updated stats)
+ """
+ if not dataset_name or not video_name:
+ return "No video selected", gr.update(), None, "", "", gr.update()
+
+ try:
+ self.dataset_manager.delete_video(dataset_name, video_name)
+
+ # Refresh gallery
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ gallery_items = [
+ (
+ i["thumbnail"],
+ (
+ i.get("caption", "")[:50] + "..."
+ if len(i.get("caption", "")) > 50
+ else i.get("caption", "") or "No caption"
+ ),
+ )
+ for i in items
+ if i["thumbnail"]
+ ]
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+
+ return f"Deleted {video_name}", gr.update(value=gallery_items), None, "", "", stats
+ except Exception as e:
+ return f"Error: {e}", gr.update(), None, "", "", gr.update()
+
+ def split_scenes_and_add(
+ self,
+ dataset_name: str,
+ video_file: str,
+ detector_type: str,
+ min_scene_length: int | None,
+ threshold: float | None,
+ filter_shorter: str | None,
+ max_scenes: int | None,
+ ) -> tuple[str, gr.Gallery, dict]:
+ """Split a video into scenes and add them to the dataset.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_file: Path to the video file
+ detector_type: Type of scene detector to use
+ min_scene_length: Minimum scene length in frames
+ threshold: Detection threshold
+ filter_shorter: Filter scenes shorter than this duration
+ max_scenes: Maximum number of scenes (0 for unlimited)
+
+ Returns:
+ Tuple of (status message, updated gallery, updated stats)
+ """
+ if not dataset_name:
+ return "Please select a dataset first", gr.update(), gr.update()
+
+ if not video_file:
+ return "Please upload a video file", gr.update(), gr.update()
+
+ try:
+ import tempfile
+
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_path = Path(temp_dir)
+ video_path = Path(video_file)
+
+ scene_files = self.dataset_manager.split_video_scenes(
+ video_path=video_path,
+ output_dir=temp_path,
+ detector_type=detector_type,
+ min_scene_length=int(min_scene_length) if min_scene_length else None,
+ threshold=float(threshold) if threshold else None,
+ filter_shorter_than=filter_shorter if filter_shorter and filter_shorter.strip() else None,
+ max_scenes=int(max_scenes) if max_scenes and max_scenes > 0 else None,
+ )
+
+ if not scene_files:
+ return "No scenes detected", gr.update(), gr.update()
+
+ result = self.dataset_manager.add_videos(dataset_name, [str(f) for f in scene_files])
+
+ message = (
+ f"Split into {len(scene_files)} scenes using {detector_type} detector. "
+ f"Added {len(result['added'])} scenes to dataset."
+ )
+ if result["failed"]:
+ message += f" {len(result['failed'])} failed."
+
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+ gallery_items = [
+ (
+ i["thumbnail"],
+ (
+ i.get("caption", "")[:50] + "..."
+ if len(i.get("caption", "")) > 50
+ else i.get("caption", "") or "No caption"
+ ),
+ )
+ for i in items
+ if i["thumbnail"]
+ ]
+ stats = self.dataset_manager.get_dataset_stats(dataset_name)
+
+ return message, gr.update(value=gallery_items), stats
+
+ except Exception as e:
+ logger.error(f"Scene splitting failed: {e}", exc_info=True)
+ return f"Error: {e}", gr.update(), gr.update()
+
+ def filter_dataset_gallery(self, dataset_name: str, search_text: str) -> list:
+ """Filter gallery by caption search.
+
+ Args:
+ dataset_name: Name of the dataset
+ search_text: Text to search for in captions
+
+ Returns:
+ Filtered gallery items
+ """
+ if not dataset_name:
+ return []
+
+ items = self.dataset_manager.get_dataset_items(dataset_name)
+
+ if search_text and search_text.strip():
+ search_lower = search_text.lower()
+ items = [i for i in items if search_lower in i.get("caption", "").lower()]
+
+ gallery_items = [
+ (
+ i["thumbnail"],
+ (
+ i.get("caption", "")[:50] + "..."
+ if len(i.get("caption", "")) > 50
+ else i.get("caption", "") or "No caption"
+ ),
+ )
+ for i in items
+ if i["thumbnail"]
+ ]
+
+ return gallery_items
+
+ def refresh_job_list(self, show_all: bool = False) -> list[list]:
+ """Get formatted list of jobs for display.
+
+ Args:
+ show_all: If True, show all jobs. If False, show only latest job.
+
+ Returns:
+ List of job rows for dataframe display
+ """
+ jobs = self.job_db.get_all_jobs()
+
+ # If not showing all, only show the most recent job
+ if not show_all and jobs:
+ jobs = [jobs[0]] # Jobs are already ordered by ID desc (most recent first)
+
+ # Format for dataframe
+ rows = []
+ for job in jobs:
+ rows.append(
+ [
+ job["id"],
+ job["status"],
+ job["dataset_name"],
+ job["created_at"][:19] if job.get("created_at") else "",
+ job.get("progress", "")[:50] if job.get("progress") else "",
+ ]
+ )
+
+ return rows
+
+ def get_latest_job_id(self) -> int:
+ """Get the ID of the latest job.
+
+ Returns:
+ Latest job ID or 0 if no jobs
+ """
+ jobs = self.job_db.get_all_jobs()
+ return jobs[0]["id"] if jobs else 0
+
+ def get_running_job(self) -> dict | None:
+ """Get the currently running job if any.
+
+ Returns:
+ Running job dict or None
+ """
+ jobs = self.job_db.get_all_jobs()
+ for job in jobs:
+ if job["status"] == JobStatus.RUNNING:
+ return job
+ return None
+
+ def get_current_job_display(self) -> tuple[str, str, str | None, str, str]:
+ """Get formatted display for the current running job.
+
+ Returns:
+ Tuple of (status_html, job_info, validation_sample, validation_prompt, logs)
+ """
+ running_job = self.get_running_job()
+
+ if not running_job:
+ return (
+                '<div style="text-align: center; padding: 20px; color: #666;">No job currently running</div>',
+ "",
+ None,
+ "",
+ "",
+ )
+
+ # Format job info
+ job_info = f"""**Job #{running_job["id"]}** - {running_job["dataset_name"]}
+**Status:** {running_job["status"]}
+**Progress:** {running_job.get("progress", "Starting...")}
+**Started:** {running_job.get("started_at", "N/A")[:19] if running_job.get("started_at") else "N/A"}
+"""
+
+ # Get validation sample
+ validation_sample = running_job.get("validation_sample")
+ if validation_sample and not Path(validation_sample).is_absolute():
+ validation_sample = str(PROJECT_ROOT / validation_sample)
+
+ # Get validation prompt from job params
+ validation_prompt = running_job.get("params", {}).get("validation_prompt", "")
+
+ # Get logs
+ logs = running_job.get("logs", "") or "Waiting for logs..."
+
+ # Status HTML with color coding
+ status_color = "#2563eb" # blue for running
+ status_html = (
+            f'<div style="background-color: {status_color}; color: white; padding: 10px; border-radius: 8px;">'
+            f'▶️ Training in Progress - Job #{running_job["id"]}'
+            "</div>"
+ )
+
+ return status_html, job_info, validation_sample, validation_prompt, logs
+
+ def view_job_details(self, job_id: int | float) -> tuple[int, dict, str | None, str, gr.Accordion]:
+ """View job details by ID.
+
+ Args:
+ job_id: The job ID to view
+
+ Returns:
+ Tuple of (job_id, job_details, validation_sample, logs, accordion_update)
+ """
+ try:
+ if not job_id or job_id == 0:
+ return 0, {}, None, "Please enter a valid Job ID", gr.Accordion(open=False)
+
+ job = self.job_db.get_job(int(job_id))
+ if job:
+ logs = job.get("logs", "") or "No logs available yet"
+ validation_sample = job.get("validation_sample")
+ # Convert path to absolute if it's a relative path
+ if validation_sample and not Path(validation_sample).is_absolute():
+ validation_sample = str(PROJECT_ROOT / validation_sample)
+ return int(job_id), job, validation_sample, logs, gr.Accordion(open=True)
+ else:
+ return int(job_id), {}, None, f"Job #{int(job_id)} not found", gr.Accordion(open=True)
+ except (TypeError, ValueError) as e:
+ logger.error(f"Error viewing job details: {e}")
+ return 0, {}, None, f"Error: {e}", gr.Accordion(open=False)
+
+ def clear_completed_jobs(self) -> tuple[str, list[list]]:
+ """Clear all completed, failed, and cancelled jobs.
+
+ Returns:
+ Tuple of (status message, updated job list)
+ """
+ self.job_db.clear_completed_jobs()
+ return "✅ Cleared all completed jobs", self.refresh_job_list(show_all=True)
+
+ def stop_worker(self) -> tuple[str, list[list]]:
+ """Stop the worker and cancel the current running job.
+
+ Returns:
+ Tuple of (status message, updated job list)
+ """
+ try:
+ # First, cancel any running jobs
+ jobs = self.job_db.get_all_jobs()
+ cancelled_jobs = []
+ for job in jobs:
+ if job["status"] == JobStatus.RUNNING:
+ self.job_db.update_job_status(
+ job["id"], JobStatus.CANCELLED, error_message="Worker stopped by user"
+ )
+ cancelled_jobs.append(job["id"])
+
+ # Send shutdown signal to worker
+ db_path = PROJECT_ROOT / "jobs.db"
+ shutdown_signal_file = db_path.parent / ".worker_shutdown_signal"
+ shutdown_signal_file.touch()
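+            # The worker process is expected to poll for this marker file and shut down
+            # gracefully once it appears (see scripts/jobs/run_worker.py).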
+
+ if cancelled_jobs:
+ msg = f"✅ Cancelled job(s) #{', #'.join(map(str, cancelled_jobs))} and sent shutdown signal to worker."
+ else:
+ msg = "✅ Shutdown signal sent to worker. No running jobs to cancel."
+
+ return msg, self.refresh_job_list(show_all=True)
+ except Exception as e:
+ logger.error(f"Error stopping worker: {e}")
+ return f"❌ Error: {e}", self.refresh_job_list(show_all=True)
+
+ def start_worker(self) -> str:
+ """Start the worker process.
+
+ Returns:
+ Status message
+ """
+ try:
+ import subprocess
+ import sys
+
+ # Check if already running
+ status = self.check_worker_status()
+ if "running" in status.lower():
+ return f"⚠️ Worker is already running\n{status}"
+
+ # Remove shutdown signal if it exists
+ shutdown_signal_file = PROJECT_ROOT / ".worker_shutdown_signal"
+ if shutdown_signal_file.exists():
+ shutdown_signal_file.unlink()
+
+ # Start worker process
+ worker_script = PROJECT_ROOT / "scripts" / "jobs" / "run_worker.py"
+ python_exe = sys.executable
+
+ subprocess.Popen(
+ [python_exe, str(worker_script)],
+ cwd=str(PROJECT_ROOT),
+ stdout=subprocess.DEVNULL,
+ stderr=subprocess.DEVNULL,
+ start_new_session=True, # Detach from parent
+ )
+
+ return "✅ Worker started successfully"
+ except Exception as e:
+ logger.error(f"Error starting worker: {e}")
+ return f"❌ Error starting worker: {e}"
+
+ def check_worker_status(self) -> str:
+ """Check if worker is running.
+
+ Returns:
+ Worker status message
+ """
+ try:
+ import psutil
+
+ # Look for worker process
+ for proc in psutil.process_iter(["pid", "name", "cmdline"]):
+ try:
+ cmdline = proc.info.get("cmdline", [])
+ if cmdline and "run_worker.py" in " ".join(cmdline):
+ return f"✅ Worker is running (PID: {proc.info['pid']})"
+ except (psutil.NoSuchProcess, psutil.AccessDenied):
+ continue
+
+ return "⚠️ Worker is not running. Start it with: python scripts/jobs/run_worker.py"
+ except ImportError:
+ return "💡 Install psutil to check worker status: pip install psutil"
+ except Exception as e:
+ return f"❓ Could not determine worker status: {e}"
+
+ def create_ui(self) -> gr.Blocks: # noqa: PLR0915
+ """Create the Gradio UI."""
+ with gr.Blocks() as blocks:
+ gr.Markdown("# LTX-Video Trainer")
+
+ with gr.Tab("Training"):
+ gr.Markdown("# 🎬 Train LTX-Video LoRA")
+
+ # Dataset Selection
+ with gr.Group():
+ gr.Markdown("### 📁 Dataset")
+ dataset_for_training = gr.Dropdown(
+ choices=self.dataset_manager.list_datasets(),
+ label="Select Dataset",
+ interactive=True,
+ info="Choose a dataset you created in the Datasets tab",
+ )
+
+ # Training Configuration
+ with gr.Row():
+ with gr.Column(scale=1):
+ with gr.Group():
+ gr.Markdown("### ⚙️ Training Parameters")
+ model_source = gr.Dropdown(
+ choices=[str(v) for v in LtxvModelVersion],
+ value=str(LtxvModelVersion.latest()),
+ label="Model Version",
+ info="Select the model version to use for training",
+ )
+
+ with gr.Row():
+ lr = gr.Number(value=2e-4, label="Learning Rate", scale=1)
+ steps = gr.Number(
+ value=1500,
+ label="Training Steps",
+ precision=0,
+ info="Total training steps",
+ scale=1,
+ )
+
+ with gr.Row():
+ lora_rank = gr.Dropdown(
+ choices=list(range(8, 257, 8)),
+ value=128,
+ label="LoRA Rank",
+ info="Higher = more capacity",
+ scale=1,
+ )
+ batch_size = gr.Number(value=1, label="Batch Size", precision=0, scale=1)
+
+ validation_interval = gr.Number(
+ value=100,
+ label="Validation Interval",
+ precision=0,
+ info="Steps between validation samples",
+ minimum=1,
+ )
+
+ with gr.Group():
+ gr.Markdown("### 🎞️ Video Settings")
+ id_token = gr.Textbox(
+ label="LoRA ID Token",
+                                placeholder="Optional: e.g., <lora> or <token>",
+ value="",
+ info="Prepended to training captions",
+ )
+
+ with gr.Row():
+ width = gr.Dropdown(
+ choices=list(range(256, 1025, 32)),
+ value=768,
+ label="Width",
+ info="Multiple of 32",
+ )
+ height = gr.Dropdown(
+ choices=list(range(256, 1025, 32)),
+ value=768,
+ label="Height",
+ info="Multiple of 32",
+ )
+ num_frames = gr.Dropdown(
+ choices=list(range(9, 129, 8)),
+ value=25,
+ label="Frames",
+                                info="Multiple of 8, plus 1 (e.g. 25)",
+ )
+
+ with gr.Column(scale=1), gr.Group():
+ # Training Monitor
+ gr.Markdown("### 📊 Training Monitor")
+
+ with gr.Row():
+ train_btn = gr.Button("▶️ Start Training", variant="primary", size="lg", scale=2)
+ reset_btn = gr.Button("🔄 Reset", variant="secondary", size="lg", scale=1)
+
+ current_job_display = gr.Markdown(value="", visible=True, label="Current Job")
+
+ self.status_output = gr.Textbox(label="Status", interactive=False)
+ self.progress_output = gr.Textbox(label="Progress", interactive=False)
+
+ gr.Markdown("**Validation Sample**")
+ self.validation_prompt = gr.Textbox(
+ label="Validation Prompt",
+ placeholder="Enter the prompt to use for validation samples",
+ value="a professional portrait video of a person with blurry bokeh background",
+ interactive=True,
+ info="Prompt used to generate sample videos during training (view in Jobs tab)",
+ lines=2,
+ )
+
+ # Advanced Settings (Collapsible)
+ with gr.Accordion("🚀 HuggingFace Hub (Optional)", open=False):
+ push_to_hub = gr.Checkbox(
+ label="Push to HuggingFace Hub",
+ value=False,
+ info="Automatically upload trained model to HuggingFace Hub",
+ )
+ with gr.Row():
+ hf_token = gr.Textbox(
+ label="HuggingFace Token",
+ type="password",
+ placeholder="hf_...",
+ info="Your HuggingFace API token",
+ )
+ hf_model_id = gr.Textbox(
+ label="Model ID",
+ placeholder="username/model-name",
+ info="Repository name on HuggingFace",
+ )
+
+ # Results Section
+ with gr.Group():
+ gr.Markdown("### ✅ Training Results")
+ with gr.Row():
+ self.download_btn = gr.DownloadButton(
+ label="📥 Download LoRA Weights",
+ visible=False,
+ interactive=True,
+ size="lg",
+ )
+ self.hf_repo_link = gr.HTML(
+ value="",
+ visible=False,
+ label="HuggingFace Hub",
+ )
+
+ # Jobs Queue Link
+ gr.Markdown("---") # Separator
+ gr.Markdown("### 📋 Training Queue Status")
+ gr.Markdown(
+ "💡 **Tip:** Switch to the **Jobs** tab to monitor training progress, "
+ "view logs, and manage the queue."
+ )
+
+ # Jobs Tab
+ with gr.Tab("Jobs"):
+ gr.Markdown("# 📋 Training Jobs")
+
+ # Worker Status Section
+ with gr.Group():
+ gr.Markdown("## 🔧 Worker Status")
+ with gr.Row():
+ check_worker_btn = gr.Button("🔍 Check Status", size="sm", scale=1)
+ stop_worker_btn = gr.Button("🛑 Stop Worker", variant="stop", size="sm", scale=1)
+
+ queue_status = gr.Textbox(
+ label="",
+ interactive=False,
+ show_label=False,
+ container=True,
+ placeholder="Worker status will appear here...",
+ )
+
+ gr.Markdown("---")
+
+ # Current Running Job Section
+ with gr.Group():
+ gr.Markdown("## ▶️ Current Training Job")
+
+ current_job_status = gr.HTML(
+ value=(
+                        '<div style="text-align: center; padding: 20px; color: #666;">'
+                        "No job currently running</div>"
+ )
+ )
+
+ with gr.Row():
+ with gr.Column(scale=2):
+ current_job_info = gr.Markdown(value="")
+
+ gr.Markdown("#### 🎬 Latest Validation Sample")
+ current_job_sample = gr.Video(
+ label="",
+ show_label=False,
+ interactive=False,
+ autoplay=True,
+ loop=True,
+ height=300,
+ )
+
+ gr.Markdown("#### 💬 Validation Prompt")
+ current_job_prompt = gr.Textbox(
+ label="",
+ show_label=False,
+ interactive=False,
+ lines=2,
+ placeholder="Validation prompt will appear here...",
+ )
+
+ with gr.Column(scale=3):
+ gr.Markdown("#### 📝 Training Logs (Live)")
+ current_job_logs = gr.Textbox(
+ label="",
+ lines=25,
+ max_lines=30,
+ interactive=False,
+ show_copy_button=True,
+ placeholder="Logs will appear here when training starts...",
+ autoscroll=True,
+ )
+
+ gr.Markdown("---")
+
+ # Job History Section
+ with gr.Group():
+ gr.Markdown("## 📚 Job History")
+
+ with gr.Row():
+ refresh_jobs_btn = gr.Button("🔄 Refresh Jobs", size="sm")
+ clear_completed_btn = gr.Button("🗑️ Clear Completed", size="sm", variant="secondary")
+
+ job_table = gr.Dataframe(
+ headers=["ID", "Status", "Dataset", "Created", "Progress"],
+ datatype=["number", "str", "str", "str", "str"],
+ label="All Jobs",
+ interactive=False,
+ wrap=True,
+ row_count=10,
+ )
+
+ # Manual job selection for viewing older jobs
+ with gr.Accordion("🔍 View Specific Job", open=False):
+ with gr.Row():
+ selected_job_id = gr.Number(
+ label="Job ID",
+ precision=0,
+ value=0,
+ interactive=True,
+ scale=1,
+ )
+ view_job_btn = gr.Button("View", size="sm", scale=1)
+
+ selected_job_details = gr.JSON(label="Job Configuration", value={})
+
+ gr.Markdown("#### Validation Sample")
+ selected_job_sample = gr.Video(
+ label="",
+ show_label=False,
+ interactive=False,
+ autoplay=True,
+ loop=True,
+ height=250,
+ )
+
+ gr.Markdown("#### Logs")
+ selected_job_logs = gr.Textbox(
+ label="",
+ lines=15,
+ interactive=False,
+ show_copy_button=True,
+ )
+
+ with gr.Accordion("💡 Worker Information", open=False):
+ gr.Markdown(
+ "### Worker Management\n\n"
+ "The worker processes training jobs from the queue.\n\n"
+ "**To start:** Click the '▶️ Start Worker' button above, or run:\n"
+ "```bash\n"
+ "python scripts/jobs/run_worker.py\n"
+ "```\n\n"
+ "**To stop:** Click the '🛑 Stop Worker' button above.\n\n"
+ "**To check status:** Click the '🔍 Check Worker Status' button."
+ )
+
+ gr.Textbox(
+ label="Worker Status",
+ interactive=False,
+ placeholder="Click 'Check Worker Status' to see if worker is running...",
+ )
+
+ # Event handlers
+
+ # Auto-refresh current job on timer
+ def auto_refresh_current_job() -> tuple[str, str, str | None, str, str, list, str]:
+ """Auto-refresh the current running job display and check worker status."""
+ status_html, job_info, validation_sample, validation_prompt, logs = self.get_current_job_display()
+ job_list = self.refresh_job_list(show_all=True)
+ worker_status = self.check_worker_status()
+ return status_html, job_info, validation_sample, validation_prompt, logs, job_list, worker_status
+
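+            # Poll every few seconds so the Jobs tab reflects worker progress,
+            # validation samples and logs without a manual refresh.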
+ job_refresh_timer = gr.Timer(value=5)
+ job_refresh_timer.tick(
+ fn=auto_refresh_current_job,
+ outputs=[
+ current_job_status,
+ current_job_info,
+ current_job_sample,
+ current_job_prompt,
+ current_job_logs,
+ job_table,
+ queue_status,
+ ],
+ show_progress=False,
+ )
+
+ # Manual refresh
+ refresh_jobs_btn.click(
+ auto_refresh_current_job,
+ outputs=[
+ current_job_status,
+ current_job_info,
+ current_job_sample,
+ current_job_prompt,
+ current_job_logs,
+ job_table,
+ queue_status,
+ ],
+ )
+
+ # Clear completed jobs
+ clear_completed_btn.click(
+ self.clear_completed_jobs,
+ outputs=[queue_status, job_table],
+ )
+
+ # Worker control
+ check_worker_btn.click(
+ self.check_worker_status,
+ outputs=[queue_status],
+ )
+
+ stop_worker_btn.click(
+ self.stop_worker,
+ outputs=[queue_status, job_table],
+ )
+
+ # View specific job
+ def view_specific_job(job_id: int | float) -> tuple[dict, str | None, str]:
+ if not job_id or job_id == 0:
+ return {}, None, "Please enter a valid Job ID"
+ job_id_int, details, sample, logs, _ = self.view_job_details(job_id)
+ return details, sample, logs
+
+ view_job_btn.click(
+ view_specific_job,
+ inputs=[selected_job_id],
+ outputs=[selected_job_details, selected_job_sample, selected_job_logs],
+ )
+
+ selected_job_id.submit(
+ view_specific_job,
+ inputs=[selected_job_id],
+ outputs=[selected_job_details, selected_job_sample, selected_job_logs],
+ )
+
+ # Datasets Tab
+ with gr.Tab("Datasets"), gr.Row():
+ with gr.Column(scale=1):
+ # Dataset selection/creation
+ gr.Markdown("## Manage Datasets")
+
+ dataset_dropdown = gr.Dropdown(
+ choices=self.dataset_manager.list_datasets(),
+ label="Select Dataset",
+ interactive=True,
+ )
+
+ with gr.Row():
+ new_dataset_name = gr.Textbox(label="New Dataset Name", placeholder="my_dataset")
+ create_dataset_btn = gr.Button("Create Dataset", variant="primary")
+
+ # Dataset statistics
+ stats_box = gr.JSON(label="Dataset Statistics", value={})
+
+ # Captioning settings
+ gr.Markdown("### Caption Settings")
+
+ with gr.Accordion("🎨 VLM Captioning Options", open=False):
+ vlm_instruction = gr.Textbox(
+ label="Custom VLM Instruction",
+ placeholder="Leave empty to use default instruction",
+ value="",
+ lines=3,
+ info="Custom instruction for the vision-language model when generating captions",
+ )
+ vlm_use_8bit = gr.Checkbox(
+ label="Use 8-bit Quantization",
+ value=True,
+ info="Reduces VRAM usage during captioning",
+ )
+
+ # Batch operations
+ gr.Markdown("### Batch Operations")
+
+ with gr.Row():
+ auto_caption_btn = gr.Button("Auto-Caption All Uncaptioned", size="sm")
+ validate_btn = gr.Button("Validate Dataset", size="sm")
+
+ validation_result = gr.JSON(label="Validation Results", value={})
+
+ # Preprocessing options
+ gr.Markdown("### Preprocessing Options")
+
+ with gr.Accordion("⚙️ Advanced Preprocessing", open=False):
+ decode_videos_check = gr.Checkbox(
+ label="Decode Videos for Verification",
+ value=False,
+ info="Decode preprocessed latents to verify quality (slower, uses more disk space)",
+ )
+ load_text_encoder_8bit = gr.Checkbox(
+ label="Load Text Encoder in 8-bit",
+ value=False,
+ info="Reduces VRAM usage during text embedding computation",
+ )
+
+ with gr.Row():
+ preprocess_btn = gr.Button("Preprocess Dataset", variant="primary")
+
+ preprocess_status = gr.Textbox(label="Preprocessing Status", interactive=False)
+
+ with gr.Column(scale=3):
+ # Video upload area
+ gr.Markdown("## Upload Videos")
+
+ video_uploader = gr.File(
+ label="Drag and drop videos here",
+ file_count="multiple",
+ file_types=["video"],
+ type="filepath",
+ )
+
+ upload_btn = gr.Button("Add to Dataset")
+ upload_status = gr.Textbox(label="Upload Status", interactive=False)
+
+ # IC-LoRA reference videos
+ with gr.Accordion("🎯 IC-LoRA: Add Reference Videos (Optional)", open=False):
+ gr.Markdown(
+ "**IC-LoRA Training:** Upload reference videos paired with target videos "
+ "for advanced video-to-video transformations. "
+ "Reference videos must match target videos in count, resolution, and length."
+ )
+ reference_uploader = gr.File(
+ label="Reference Videos (for IC-LoRA) - Upload in same order as target videos",
+ file_count="multiple",
+ file_types=["video"],
+ type="filepath",
+ )
+ upload_with_ref_btn = gr.Button("Add Videos with References", variant="secondary")
+
+ # Scene splitting
+ with gr.Accordion("✂️ Split Long Videos into Scenes", open=False):
+ gr.Markdown(
+ "Upload a long video and automatically split it into individual scenes "
+ "using advanced detection algorithms"
+ )
+ scene_video_uploader = gr.File(
+ label="Video to Split",
+ file_count="single",
+ file_types=["video"],
+ type="filepath",
+ )
+ with gr.Row():
+ scene_detector_type = gr.Dropdown(
+ choices=["content", "adaptive", "threshold", "histogram"],
+ value="content",
+ label="Detection Algorithm",
+ info=(
+                                    "Content: HSV color changes | Adaptive: two-pass detection (handles camera motion) | "
+                                    "Threshold: fade in/out | Histogram: YUV histogram changes"
+ ),
+ )
+ scene_threshold = gr.Number(
+ label="Detection Threshold",
+ value=27.0,
+ info="Lower = more sensitive",
+ )
+ with gr.Row():
+ scene_min_length = gr.Number(
+ label="Min Scene Length (frames)",
+ value=30,
+ precision=0,
+ info="Minimum scene duration during detection",
+ )
+ scene_filter_shorter = gr.Textbox(
+ label="Filter Shorter Than",
+ value="2s",
+ info="Filter scenes shorter than duration (e.g., '2s', '60' frames)",
+ )
+ scene_max_scenes = gr.Number(
+ label="Max Scenes (0=unlimited)",
+ value=0,
+ precision=0,
+ info="Maximum number of scenes to produce",
+ )
+ split_scenes_btn = gr.Button("Split Scenes and Add to Dataset", variant="primary")
+
+ # Visual dataset browser
+ gr.Markdown("## Dataset Browser")
+
+ with gr.Row():
+ search_box = gr.Textbox(
+ label="Search captions", placeholder="Filter by caption text...", scale=4
+ )
+ refresh_btn = gr.Button("🔄 Refresh", size="sm", scale=1)
+
+ # Video gallery with captions
+ dataset_gallery = gr.Gallery(
+ label="Videos",
+ columns=4,
+ height="auto",
+ object_fit="contain",
+ allow_preview=True,
+ )
+
+ # Selected video editor
+ gr.Markdown("### Edit Selected Video")
+
+ with gr.Row():
+ selected_video = gr.Video(label="Preview", interactive=False)
+
+ with gr.Column():
+ selected_video_name = gr.Textbox(label="Video Name", interactive=False)
+
+ caption_editor = gr.Textbox(
+ label="Caption",
+ placeholder="Enter caption for this video...",
+ lines=3,
+ interactive=True,
+ )
+
+ with gr.Row():
+ save_caption_btn = gr.Button("Save Caption", variant="primary")
+ generate_caption_btn = gr.Button("Auto-Generate")
+ delete_video_btn = gr.Button("Delete Video", variant="stop")
+
+ # Event handlers
+ # Update HF fields visibility based on push_to_hub checkbox
+ push_to_hub.change(
+ lambda x: {
+ hf_token: gr.update(visible=x),
+ hf_model_id: gr.update(visible=x),
+ },
+ inputs=[push_to_hub],
+ outputs=[hf_token, hf_model_id],
+ )
+
+ train_btn.click(
+ lambda dataset_name,
+ validation_prompt,
+ lr,
+ steps,
+ lora_rank,
+ batch_size,
+ model_source,
+ width,
+ height,
+ num_frames,
+ push_to_hub,
+ hf_model_id,
+ hf_token,
+ id_token,
+ validation_interval: self.start_training(
+ TrainingParams(
+ dataset_name=dataset_name,
+ validation_prompt=validation_prompt,
+ learning_rate=lr,
+ steps=steps,
+ lora_rank=lora_rank,
+ batch_size=batch_size,
+ model_source=model_source,
+ width=width,
+ height=height,
+ num_frames=num_frames,
+ push_to_hub=push_to_hub,
+ hf_model_id=hf_model_id,
+ hf_token=hf_token,
+ id_token=id_token,
+ validation_interval=validation_interval,
+ )
+ ),
+ inputs=[
+ dataset_for_training,
+ self.validation_prompt,
+ lr,
+ steps,
+ lora_rank,
+ batch_size,
+ model_source,
+ width,
+ height,
+ num_frames,
+ push_to_hub,
+ hf_model_id,
+ hf_token,
+ id_token,
+ validation_interval,
+ ],
+ outputs=[self.status_output, train_btn],
+ )
+
+            # Poll training progress on a fixed 10-second interval
+            timer = gr.Timer(value=10)
+ timer.tick(
+ fn=self.update_progress,
+ inputs=None,
+ outputs=[
+ current_job_display,
+ self.status_output,
+ self.download_btn,
+ self.hf_repo_link,
+ self.hf_repo_link,
+ ],
+ show_progress=False,
+ )
+
+ # Handle download button click
+ self.download_btn.click(self.get_model_path, inputs=None, outputs=[self.download_btn])
+
+ # Handle reset button click
+ reset_btn.click(
+ self.reset_interface,
+ inputs=None,
+ outputs=[
+ self.validation_prompt,
+ self.status_output,
+ self.progress_output,
+ self.download_btn,
+ self.hf_repo_link,
+ ],
+ )
+
+ # Dataset event handlers
+ create_dataset_btn.click(
+ self.create_new_dataset,
+ inputs=[new_dataset_name],
+ outputs=[
+ upload_status,
+ dataset_dropdown,
+ dataset_gallery,
+ stats_box,
+ selected_video,
+ selected_video_name,
+ caption_editor,
+ dataset_for_training, # Update training tab dropdown
+ ],
+ )
+
+ dataset_dropdown.change(
+ self.load_dataset,
+ inputs=[dataset_dropdown],
+ outputs=[
+ dataset_dropdown,
+ dataset_gallery,
+ stats_box,
+ selected_video,
+ selected_video_name,
+ caption_editor,
+ ],
+ )
+
+ upload_btn.click(
+ self.upload_videos_to_dataset,
+ inputs=[video_uploader, dataset_dropdown],
+ outputs=[upload_status, dataset_gallery, stats_box],
+ )
+
+ upload_with_ref_btn.click(
+ self.upload_videos_with_references,
+ inputs=[video_uploader, reference_uploader, dataset_dropdown],
+ outputs=[upload_status, dataset_gallery, stats_box],
+ )
+
+ split_scenes_btn.click(
+ self.split_scenes_and_add,
+ inputs=[
+ dataset_dropdown,
+ scene_video_uploader,
+ scene_detector_type,
+ scene_min_length,
+ scene_threshold,
+ scene_filter_shorter,
+ scene_max_scenes,
+ ],
+ outputs=[upload_status, dataset_gallery, stats_box],
+ )
+
+ dataset_gallery.select(
+ self.select_video_from_gallery,
+ inputs=[dataset_dropdown],
+ outputs=[selected_video, selected_video_name, caption_editor],
+ )
+
+ save_caption_btn.click(
+ self.save_caption_edit,
+ inputs=[dataset_dropdown, selected_video_name, caption_editor],
+ outputs=[upload_status, stats_box],
+ )
+
+ generate_caption_btn.click(
+ self.auto_caption_single,
+ inputs=[dataset_dropdown, selected_video_name, vlm_instruction, vlm_use_8bit],
+ outputs=[caption_editor, upload_status],
+ )
+
+ auto_caption_btn.click(
+ self.auto_caption_all_uncaptioned,
+ inputs=[dataset_dropdown, vlm_instruction, vlm_use_8bit],
+ outputs=[preprocess_status, stats_box],
+ )
+
+ validate_btn.click(self.validate_dataset_ui, inputs=[dataset_dropdown], outputs=[validation_result])
+
+ preprocess_btn.click(
+ self.preprocess_dataset_ui,
+ inputs=[
+ dataset_dropdown,
+ width,
+ height,
+ num_frames,
+ id_token,
+ decode_videos_check,
+ load_text_encoder_8bit,
+ ],
+ outputs=[preprocess_status, stats_box],
+ )
+
+ refresh_btn.click(
+ self.load_dataset,
+ inputs=[dataset_dropdown],
+ outputs=[
+ dataset_dropdown,
+ dataset_gallery,
+ stats_box,
+ selected_video,
+ selected_video_name,
+ caption_editor,
+ ],
+ )
+
+ delete_video_btn.click(
+ self.delete_video_from_dataset,
+ inputs=[dataset_dropdown, selected_video_name],
+ outputs=[
+ upload_status,
+ dataset_gallery,
+ selected_video,
+ selected_video_name,
+ caption_editor,
+ stats_box,
+ ],
+ )
+
+ search_box.change(
+ self.filter_dataset_gallery, inputs=[dataset_dropdown, search_box], outputs=[dataset_gallery]
+ )
+
+ return blocks
+
+
+def main() -> None:
+ """Main entry point."""
+ ui = GradioUI()
+ demo = ui.create_ui()
+ demo.queue()
+ demo.launch(server_name="0.0.0.0", server_port=7860)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/dataset_manager.py b/scripts/dataset_manager.py
new file mode 100644
index 0000000..da45d17
--- /dev/null
+++ b/scripts/dataset_manager.py
@@ -0,0 +1,400 @@
+"""Dataset Manager for LTX-Video-Trainer.
+
+This module provides utilities for managing video datasets, including:
+- Creating and organizing datasets
+- Adding videos with automatic thumbnail generation
+- Managing captions
+- Dataset validation
+"""
+
+import json
+import shutil
+from pathlib import Path
+from typing import Optional
+
+import cv2
+
+
+class DatasetManager:
+ """Manager for video dataset operations."""
+
+ def __init__(self, datasets_root: Path = Path(__file__).parent.parent / "datasets"):
+ """Initialize the dataset manager.
+
+ Args:
+ datasets_root: Root directory for all datasets
+ """
+ self.datasets_root = datasets_root
+ self.datasets_root.mkdir(exist_ok=True)
+
+ def list_datasets(self) -> list[str]:
+ """Get all dataset names.
+
+ Returns:
+ List of dataset names
+ """
+ return [d.name for d in self.datasets_root.iterdir() if d.is_dir()]
+
+ def create_dataset(self, name: str) -> Path:
+ """Create new dataset directory structure.
+
+ Args:
+ name: Name of the dataset
+
+ Returns:
+ Path to the created dataset directory
+
+ Raises:
+ ValueError: If dataset name is invalid or already exists
+ """
+ if not name or not name.strip():
+ raise ValueError("Dataset name cannot be empty")
+
+ safe_name = "".join(c for c in name if c.isalnum() or c in ("-", "_")).strip()
+ if not safe_name:
+ raise ValueError("Dataset name must contain alphanumeric characters")
+
+ dataset_dir = self.datasets_root / safe_name
+
+ if dataset_dir.exists():
+ raise ValueError(f"Dataset '{safe_name}' already exists")
+
+ (dataset_dir / "videos").mkdir(parents=True, exist_ok=True)
+ (dataset_dir / "thumbnails").mkdir(parents=True, exist_ok=True)
+
+ dataset_json = dataset_dir / "dataset.json"
+ with open(dataset_json, "w") as f:
+ json.dump([], f)
+
+ return dataset_dir
+
+ def add_videos(
+ self, dataset_name: str, video_files: list[str], reference_files: Optional[list[str]] = None
+ ) -> dict:
+ """Add videos to dataset and generate thumbnails.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_files: List of video file paths to add
+ reference_files: Optional list of reference video files for IC-LoRA (must match video_files length)
+
+ Returns:
+ Dictionary with 'added' and 'failed' lists
+ """
+ dataset_dir = self.datasets_root / dataset_name
+ if not dataset_dir.exists():
+ raise ValueError(f"Dataset '{dataset_name}' does not exist")
+
+ videos_dir = dataset_dir / "videos"
+ thumbs_dir = dataset_dir / "thumbnails"
+
+ if reference_files:
+ references_dir = dataset_dir / "references"
+ references_dir.mkdir(exist_ok=True)
+
+ if len(reference_files) != len(video_files):
+ raise ValueError("Number of reference videos must match number of videos")
+
+ results = {"added": [], "failed": []}
+
+ dataset_json = dataset_dir / "dataset.json"
+ with open(dataset_json) as f:
+ items = json.load(f)
+
+ for idx, video_file in enumerate(video_files):
+ try:
+ video_path = Path(video_file)
+ dest_path = videos_dir / video_path.name
+
+ if dest_path.exists():
+ results["failed"].append({"video": video_path.name, "error": "Already exists"})
+ continue
+
+ shutil.copy2(video_file, dest_path)
+
+ thumb_path = self._generate_thumbnail(dest_path, thumbs_dir)
+
+ entry = {"media_path": f"videos/{video_path.name}", "caption": ""}
+
+ if reference_files and idx < len(reference_files):
+ ref_path = Path(reference_files[idx])
+ ref_dest_path = references_dir / ref_path.name
+ shutil.copy2(reference_files[idx], ref_dest_path)
+ entry["reference_path"] = f"references/{ref_path.name}"
+
+ items.append(entry)
+
+ results["added"].append({"video": video_path.name, "thumbnail": thumb_path.name})
+
+ except Exception as e:
+ results["failed"].append({"video": Path(video_file).name, "error": str(e)})
+
+ with open(dataset_json, "w") as f:
+ json.dump(items, f, indent=2)
+
+ return results
+
+ def _generate_thumbnail(self, video_path: Path, output_dir: Path) -> Path:
+ """Extract first frame as thumbnail.
+
+ Args:
+ video_path: Path to the video file
+ output_dir: Directory to save thumbnail
+
+ Returns:
+ Path to the generated thumbnail
+
+ Raises:
+ ValueError: If video cannot be read
+ """
+ cap = cv2.VideoCapture(str(video_path))
+
+        # Seek to the middle of the clip for longer videos so the thumbnail is more representative
+        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+        if frame_count > 10:
+            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count // 2)
+
+ ret, frame = cap.read()
+ cap.release()
+
+ if not ret:
+ raise ValueError(f"Could not read video: {video_path}")
+
+ thumb_path = output_dir / f"{video_path.stem}.jpg"
+
+ height, width = frame.shape[:2]
+ target_size = 256
+
+ if width > height:
+ new_width = target_size
+ new_height = int(height * (target_size / width))
+ else:
+ new_height = target_size
+ new_width = int(width * (target_size / height))
+
+ frame = cv2.resize(frame, (new_width, new_height), interpolation=cv2.INTER_AREA)
+
+        # Pad the shorter side with black borders so every thumbnail is square
+        if new_width != new_height:
+ max_dim = max(new_width, new_height)
+ square_frame = cv2.copyMakeBorder(
+ frame,
+ (max_dim - new_height) // 2,
+ (max_dim - new_height + 1) // 2,
+ (max_dim - new_width) // 2,
+ (max_dim - new_width + 1) // 2,
+ cv2.BORDER_CONSTANT,
+ value=(0, 0, 0),
+ )
+ frame = square_frame
+
+ cv2.imwrite(str(thumb_path), frame)
+ return thumb_path
+
+ def get_dataset_items(self, dataset_name: str) -> list[dict]:
+ """Get all videos and captions in dataset.
+
+ Args:
+ dataset_name: Name of the dataset
+
+ Returns:
+ List of dictionaries with 'media_path', 'caption', and 'thumbnail'
+ """
+ dataset_dir = self.datasets_root / dataset_name
+ if not dataset_dir.exists():
+ return []
+
+ dataset_json = dataset_dir / "dataset.json"
+
+ if dataset_json.exists():
+ with open(dataset_json) as f:
+ items = json.load(f)
+ else:
+ items = []
+
+ for item in items:
+ video_name = Path(item["media_path"]).name
+ thumb_path = dataset_dir / "thumbnails" / f"{Path(video_name).stem}.jpg"
+ item["thumbnail"] = str(thumb_path) if thumb_path.exists() else None
+ item["full_video_path"] = str(dataset_dir / item["media_path"])
+
+ return items
+
+ def update_caption(self, dataset_name: str, video_name: str, caption: str) -> None:
+ """Update caption for a specific video.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_name: Name of the video file
+ caption: New caption text
+ """
+ dataset_dir = self.datasets_root / dataset_name
+ dataset_json = dataset_dir / "dataset.json"
+
+ with open(dataset_json) as f:
+ items = json.load(f)
+
+ found = False
+ for item in items:
+ if Path(item["media_path"]).name == video_name:
+ item["caption"] = caption
+ found = True
+ break
+
+ if not found:
+ items.append({"media_path": f"videos/{video_name}", "caption": caption})
+
+ with open(dataset_json, "w") as f:
+ json.dump(items, f, indent=2)
+
+ def delete_video(self, dataset_name: str, video_name: str) -> None:
+ """Remove video from dataset.
+
+ Args:
+ dataset_name: Name of the dataset
+ video_name: Name of the video file to delete
+ """
+ dataset_dir = self.datasets_root / dataset_name
+
+ dataset_json = dataset_dir / "dataset.json"
+ with open(dataset_json) as f:
+ items = json.load(f)
+
+ items = [i for i in items if Path(i["media_path"]).name != video_name]
+
+ with open(dataset_json, "w") as f:
+ json.dump(items, f, indent=2)
+
+ (dataset_dir / "videos" / video_name).unlink(missing_ok=True)
+ (dataset_dir / "thumbnails" / f"{Path(video_name).stem}.jpg").unlink(missing_ok=True)
+
+ def get_dataset_stats(self, dataset_name: str) -> dict:
+ """Get statistics about dataset.
+
+ Args:
+ dataset_name: Name of the dataset
+
+ Returns:
+ Dictionary with dataset statistics
+ """
+ dataset_dir = self.datasets_root / dataset_name
+ items = self.get_dataset_items(dataset_name)
+
+ has_references = any(i.get("reference_path") for i in items)
+ precomputed_dir = dataset_dir / ".precomputed"
+ has_reference_latents = (precomputed_dir / "reference_latents").exists() if precomputed_dir.exists() else False
+
+ return {
+ "name": dataset_name,
+ "total_videos": len(items),
+ "captioned": sum(1 for i in items if i.get("caption") and i.get("caption").strip()),
+ "uncaptioned": sum(1 for i in items if not i.get("caption") or not i.get("caption").strip()),
+ "preprocessed": precomputed_dir.exists(),
+ "has_references": has_references,
+ "reference_latents_computed": has_reference_latents,
+ }
+
+ def validate_dataset(self, dataset_name: str) -> dict:
+ """Validate dataset and return issues.
+
+ Args:
+ dataset_name: Name of the dataset
+
+ Returns:
+ Dictionary with validation results
+ """
+ issues = []
+ warnings = []
+
+ items = self.get_dataset_items(dataset_name)
+ dataset_dir = self.datasets_root / dataset_name
+
+ if not items:
+ issues.append("Dataset is empty")
+ return {"valid": False, "issues": issues, "warnings": warnings, "total_videos": 0}
+
+ uncaptioned = [i for i in items if not i.get("caption") or not i.get("caption").strip()]
+ if uncaptioned:
+ warnings.append(f"{len(uncaptioned)} videos without captions")
+
+ for item in items:
+ video_path = dataset_dir / item["media_path"]
+ if not video_path.exists():
+ issues.append(f"Missing video: {item['media_path']}")
+
+ if not (dataset_dir / ".precomputed").exists():
+ warnings.append("Dataset not preprocessed - will be slower to train")
+
+        # Spot-check the first few videos to make sure they can actually be decoded
+        for item in items[:5]:
+ video_path = dataset_dir / item["media_path"]
+ if video_path.exists():
+ cap = cv2.VideoCapture(str(video_path))
+ if not cap.isOpened():
+ issues.append(f"Cannot read video: {video_path.name}")
+ cap.release()
+
+ return {
+ "valid": len(issues) == 0,
+ "issues": issues,
+ "warnings": warnings,
+ "total_videos": len(items),
+ }
+
+ def split_video_scenes(
+ self,
+ video_path: Path,
+ output_dir: Path,
+ min_scene_length: Optional[int] = None,
+ threshold: Optional[float] = None,
+ detector_type: str = "content",
+ max_scenes: Optional[int] = None,
+ filter_shorter_than: Optional[str] = None,
+ ) -> list[Path]:
+ """Split a video into scenes using advanced scene detection.
+
+ Args:
+ video_path: Path to the video file
+ output_dir: Directory to save split scenes
+ min_scene_length: Minimum scene length in frames
+ threshold: Detection threshold
+ detector_type: Type of detector ('content', 'adaptive', 'threshold', 'histogram')
+ max_scenes: Maximum number of scenes to detect
+ filter_shorter_than: Filter scenes shorter than duration (e.g., "2s", "30")
+
+ Returns:
+ List of paths to split scene files
+ """
+ from scripts.split_scenes import DetectorType, detect_and_split_scenes
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ detector_map = {
+ "content": DetectorType.CONTENT,
+ "adaptive": DetectorType.ADAPTIVE,
+ "threshold": DetectorType.THRESHOLD,
+ "histogram": DetectorType.HISTOGRAM,
+ }
+
+ detector = detector_map.get(detector_type.lower(), DetectorType.CONTENT)
+
+ detect_and_split_scenes(
+ video_path=str(video_path),
+ output_dir=output_dir,
+ detector_type=detector,
+ threshold=threshold,
+ min_scene_len=min_scene_length,
+ max_scenes=max_scenes,
+ filter_shorter_than=filter_shorter_than,
+ save_images_per_scene=0,
+ )
+
+ scene_files = sorted(output_dir.glob(f"{video_path.stem}-Scene-*.mp4"))
+ return scene_files
+
+ def delete_dataset(self, dataset_name: str) -> None:
+ """Delete entire dataset.
+
+ Args:
+ dataset_name: Name of the dataset to delete
+ """
+ dataset_dir = self.datasets_root / dataset_name
+ if dataset_dir.exists():
+ shutil.rmtree(dataset_dir)
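+
+
+# Minimal usage sketch (illustrative names and paths; nothing in this module calls it):
+#
+#   manager = DatasetManager()
+#   manager.create_dataset("my_dataset")
+#   manager.add_videos("my_dataset", ["/path/to/clip.mp4"])
+#   manager.update_caption("my_dataset", "clip.mp4", "a person walking on a beach")
+#   stats = manager.get_dataset_stats("my_dataset")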
diff --git a/scripts/jobs/__init__.py b/scripts/jobs/__init__.py
new file mode 100644
index 0000000..fdb67a1
--- /dev/null
+++ b/scripts/jobs/__init__.py
@@ -0,0 +1,6 @@
+"""Job management system for training jobs."""
+
+from scripts.jobs.database import JobDatabase, JobStatus
+from scripts.jobs.worker import QueueWorker
+
+__all__ = ["JobDatabase", "JobStatus", "QueueWorker"]
diff --git a/scripts/jobs/database.py b/scripts/jobs/database.py
new file mode 100644
index 0000000..1deb9b3
--- /dev/null
+++ b/scripts/jobs/database.py
@@ -0,0 +1,314 @@
+"""SQLite database for persistent job queue management."""
+
+import json
+import sqlite3
+import threading
+from datetime import datetime, timezone
+from enum import Enum
+from pathlib import Path
+from typing import Any, Optional
+
+
+class JobStatus(str, Enum):
+ """Job status enumeration."""
+
+ PENDING = "pending"
+ RUNNING = "running"
+ COMPLETED = "completed"
+ FAILED = "failed"
+ CANCELLED = "cancelled"
+
+
+class JobDatabase:
+ """SQLite database for managing training job queue."""
+
+ def __init__(self, db_path: Path):
+ """Initialize the job database.
+
+ Args:
+ db_path: Path to the SQLite database file
+ """
+ self.db_path = db_path
+ self.db_path.parent.mkdir(parents=True, exist_ok=True)
+ self._local = threading.local()
+ self._init_db()
+
+ def _get_connection(self) -> sqlite3.Connection:
+ """Get a thread-local database connection."""
+ if not hasattr(self._local, "conn") or self._local.conn is None:
+ self._local.conn = sqlite3.connect(str(self.db_path), check_same_thread=False)
+ self._local.conn.row_factory = sqlite3.Row
+ return self._local.conn
+
+ def _init_db(self) -> None:
+ """Initialize database schema."""
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS jobs (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ status TEXT NOT NULL,
+ dataset_name TEXT NOT NULL,
+ params TEXT NOT NULL,
+ progress TEXT,
+ error_message TEXT,
+ logs TEXT,
+ created_at TEXT NOT NULL,
+ started_at TEXT,
+ completed_at TEXT,
+ process_id INTEGER,
+ output_dir TEXT,
+ checkpoint_path TEXT,
+ validation_sample TEXT
+ )
+ """
+ )
+
+ cursor.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_status ON jobs(status)
+ """
+ )
+
+ conn.commit()
+
+ def create_job(
+ self,
+ dataset_name: str,
+ params: dict[str, Any],
+ ) -> int:
+ """Create a new training job.
+
+ Args:
+ dataset_name: Name of the dataset to train on
+ params: Training parameters dictionary
+
+ Returns:
+ Job ID
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ INSERT INTO jobs (
+ status, dataset_name, params, created_at
+ ) VALUES (?, ?, ?, ?)
+ """,
+ (JobStatus.PENDING, dataset_name, json.dumps(params), datetime.now(timezone.utc).isoformat()),
+ )
+
+ conn.commit()
+ return cursor.lastrowid
+
+ def get_job(self, job_id: int) -> Optional[dict[str, Any]]:
+ """Get job by ID.
+
+ Args:
+ job_id: Job ID
+
+ Returns:
+ Job dictionary or None if not found
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute("SELECT * FROM jobs WHERE id = ?", (job_id,))
+ row = cursor.fetchone()
+
+ if row is None:
+ return None
+
+ return self._row_to_dict(row)
+
+ def get_all_jobs(self, status: Optional[JobStatus] = None) -> list[dict[str, Any]]:
+ """Get all jobs, optionally filtered by status.
+
+ Args:
+ status: Filter by status (optional)
+
+ Returns:
+ List of job dictionaries
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ if status:
+ cursor.execute("SELECT * FROM jobs WHERE status = ? ORDER BY created_at DESC", (status,))
+ else:
+ cursor.execute("SELECT * FROM jobs ORDER BY created_at DESC")
+
+ rows = cursor.fetchall()
+ return [self._row_to_dict(row) for row in rows]
+
+ def get_next_pending_job(self) -> Optional[dict[str, Any]]:
+ """Get the next pending job (oldest first).
+
+ Returns:
+ Job dictionary or None if no pending jobs
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ SELECT * FROM jobs
+ WHERE status = ?
+ ORDER BY created_at ASC
+ LIMIT 1
+ """,
+ (JobStatus.PENDING,),
+ )
+
+ row = cursor.fetchone()
+ if row is None:
+ return None
+
+ return self._row_to_dict(row)
+
+ def update_job_status(
+ self,
+ job_id: int,
+ status: JobStatus,
+ progress: Optional[str] = None,
+ error_message: Optional[str] = None,
+ logs: Optional[str] = None,
+ process_id: Optional[int] = None,
+ output_dir: Optional[str] = None,
+ checkpoint_path: Optional[str] = None,
+ validation_sample: Optional[str] = None,
+ ) -> None:
+ """Update job status and related fields.
+
+ Args:
+ job_id: Job ID
+ status: New status
+ progress: Progress message (optional)
+ error_message: Error message if failed (optional)
+ logs: Training logs (optional)
+ process_id: Process ID if running (optional)
+ output_dir: Output directory path (optional)
+ checkpoint_path: Path to final checkpoint (optional)
+ validation_sample: Path to latest validation sample (optional)
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+        # Build the SET clause dynamically so only the provided fields are updated
+        updates = ["status = ?"]
+        values = [status]
+
+ if progress is not None:
+ updates.append("progress = ?")
+ values.append(progress)
+
+ if error_message is not None:
+ updates.append("error_message = ?")
+ values.append(error_message)
+
+ if logs is not None:
+ updates.append("logs = ?")
+ values.append(logs)
+
+ if process_id is not None:
+ updates.append("process_id = ?")
+ values.append(process_id)
+
+ if output_dir is not None:
+ updates.append("output_dir = ?")
+ values.append(output_dir)
+
+ if checkpoint_path is not None:
+ updates.append("checkpoint_path = ?")
+ values.append(checkpoint_path)
+
+ if validation_sample is not None:
+ updates.append("validation_sample = ?")
+ values.append(validation_sample)
+
+ if status == JobStatus.RUNNING and not self.get_job(job_id).get("started_at"):
+ updates.append("started_at = ?")
+ values.append(datetime.now(timezone.utc).isoformat())
+ elif status in (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED):
+ updates.append("completed_at = ?")
+ values.append(datetime.now(timezone.utc).isoformat())
+
+ values.append(job_id)
+
+ cursor.execute(
+ f"""
+ UPDATE jobs
+ SET {", ".join(updates)}
+ WHERE id = ?
+ """,
+ values,
+ )
+
+ conn.commit()
+
+ def delete_job(self, job_id: int) -> None:
+ """Delete a job from the database.
+
+ Args:
+ job_id: Job ID
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute("DELETE FROM jobs WHERE id = ?", (job_id,))
+ conn.commit()
+
+ def get_job_count(self, status: Optional[JobStatus] = None) -> int:
+ """Get count of jobs, optionally filtered by status.
+
+ Args:
+ status: Filter by status (optional)
+
+ Returns:
+ Job count
+ """
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ if status:
+ cursor.execute("SELECT COUNT(*) FROM jobs WHERE status = ?", (status,))
+ else:
+ cursor.execute("SELECT COUNT(*) FROM jobs")
+
+ return cursor.fetchone()[0]
+
+ def clear_completed_jobs(self) -> None:
+ """Clear all completed, failed, and cancelled jobs."""
+ conn = self._get_connection()
+ cursor = conn.cursor()
+
+ cursor.execute(
+ """
+ DELETE FROM jobs
+ WHERE status IN (?, ?, ?)
+ """,
+ (JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED),
+ )
+
+ conn.commit()
+
+ def _row_to_dict(self, row: sqlite3.Row) -> dict[str, Any]:
+ """Convert SQLite row to dictionary.
+
+ Args:
+ row: SQLite row
+
+ Returns:
+ Dictionary representation
+ """
+ data = dict(row)
+ if data.get("params"):
+ data["params"] = json.loads(data["params"])
+ return data
+
+ def close(self) -> None:
+ """Close database connection."""
+ if hasattr(self._local, "conn") and self._local.conn is not None:
+ self._local.conn.close()
+ self._local.conn = None
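+
+
+# Typical job lifecycle (illustrative sketch; in practice the UI and queue worker drive this):
+#
+#   db = JobDatabase(Path("jobs.db"))
+#   job_id = db.create_job("my_dataset", {"steps": 1500, "lora_rank": 128})
+#   job = db.get_next_pending_job()  # picked up by the queue worker
+#   db.update_job_status(job_id, JobStatus.RUNNING, progress="Step 1/1500")
+#   db.update_job_status(job_id, JobStatus.COMPLETED, checkpoint_path="outputs/lora.safetensors")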
diff --git a/scripts/jobs/run_worker.py b/scripts/jobs/run_worker.py
new file mode 100755
index 0000000..ac216f9
--- /dev/null
+++ b/scripts/jobs/run_worker.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+"""
+Start the queue worker to process training jobs.
+
+This script starts a background worker that continuously polls the job database
+and executes pending training jobs.
+"""
+
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from scripts.jobs.worker import main
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/jobs/worker.py b/scripts/jobs/worker.py
new file mode 100644
index 0000000..47d3511
--- /dev/null
+++ b/scripts/jobs/worker.py
@@ -0,0 +1,299 @@
+"""Queue worker that processes training jobs."""
+
+import logging
+import os
+import signal
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import Optional
+
+from scripts.jobs.database import JobDatabase, JobStatus
+
+logger = logging.getLogger(__name__)
+
+
+class QueueWorker:
+ """Worker that processes training jobs from the queue."""
+
+ def __init__(self, db_path: Path, check_interval: int = 5):
+ """Initialize the queue worker.
+
+ Args:
+ db_path: Path to the job database
+ check_interval: Seconds between checking for new jobs
+ """
+ self.db = JobDatabase(db_path)
+ self.check_interval = check_interval
+ self.running = False
+ self.current_process: Optional[subprocess.Popen] = None
+ self.current_job_id: Optional[int] = None
+
+        # Another process can request a graceful shutdown by creating this sentinel file
+        self.shutdown_signal_file = db_path.parent / ".worker_shutdown_signal"
+ if self.shutdown_signal_file.exists():
+ self.shutdown_signal_file.unlink()
+
+ signal.signal(signal.SIGINT, self._handle_shutdown)
+ signal.signal(signal.SIGTERM, self._handle_shutdown)
+
+ def _handle_shutdown(self, signum: int, _frame: object) -> None:
+ """Handle shutdown signals gracefully."""
+ logger.info(f"Received signal {signum}, shutting down gracefully...")
+ self.stop()
+
+ def start(self) -> None:
+ """Start the worker loop."""
+ self.running = True
+ logger.info("Queue worker started")
+
+ while self.running:
+ try:
+ if self.shutdown_signal_file.exists():
+ logger.info("Shutdown signal detected, stopping worker...")
+ self.stop()
+ break
+
+ job = self.db.get_next_pending_job()
+
+ if job:
+ self._process_job(job)
+ else:
+ time.sleep(self.check_interval)
+
+ except Exception as e:
+ logger.error(f"Error in worker loop: {e}", exc_info=True)
+ time.sleep(self.check_interval)
+
+ logger.info("Queue worker stopped")
+
+ if self.shutdown_signal_file.exists():
+ self.shutdown_signal_file.unlink()
+
+ def stop(self) -> None:
+ """Stop the worker loop."""
+ self.running = False
+
+ if self.current_process and self.current_process.poll() is None:
+ logger.info(f"Cancelling current job {self.current_job_id}")
+ self.current_process.terminate()
+ try:
+ self.current_process.wait(timeout=10)
+ except subprocess.TimeoutExpired:
+ self.current_process.kill()
+
+ if self.current_job_id:
+ self.db.update_job_status(self.current_job_id, JobStatus.CANCELLED, error_message="Worker shutdown")
+
+ def _process_job(self, job: dict) -> None:
+ """Process a single training job.
+
+ Args:
+ job: Job dictionary from database
+ """
+ job_id = job["id"]
+ self.current_job_id = job_id
+
+ logger.info(f"Starting job {job_id} for dataset: {job['dataset_name']}")
+
+ self.db.update_job_status(job_id, JobStatus.RUNNING)
+
+ try:
+ cmd = self._build_training_command(job)
+
+ self.current_process = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ )
+
+ self.db.update_job_status(job_id, JobStatus.RUNNING, process_id=self.current_process.pid)
+
+ log_lines = []
+ for line in self.current_process.stdout:
+ if not self.running:
+ break
+
+ current_job = self.db.get_job(job_id)
+ if current_job and current_job["status"] == JobStatus.CANCELLED:
+ logger.info(f"Job {job_id} was cancelled, terminating process")
+ self.current_process.terminate()
+ try:
+ self.current_process.wait(timeout=5)
+ except subprocess.TimeoutExpired:
+ self.current_process.kill()
+ break
+
+ log_lines.append(line)
+
+ if len(log_lines) > 500:
+ log_lines = log_lines[-500:]
+
+ if "Step" in line or "steps" in line.lower():
+ self.db.update_job_status(job_id, JobStatus.RUNNING, progress=line.strip(), logs="".join(log_lines))
+
+ returncode = self.current_process.wait()
+
+ final_logs = "".join(log_lines)
+
+ current_job = self.db.get_job(job_id)
+ if current_job and current_job["status"] == JobStatus.CANCELLED:
+ logger.info(f"Job {job_id} was cancelled")
+ self.db.update_job_status(
+ job_id,
+ JobStatus.CANCELLED,
+ logs=final_logs,
+ )
+ elif returncode == 0:
+ logger.info(f"Job {job_id} completed successfully")
+
+ output_dir = self._get_output_dir(job)
+ checkpoint_path = self._find_latest_checkpoint(output_dir)
+
+ self.db.update_job_status(
+ job_id,
+ JobStatus.COMPLETED,
+ progress="Training completed",
+ logs=final_logs,
+ output_dir=str(output_dir) if output_dir else None,
+ checkpoint_path=str(checkpoint_path) if checkpoint_path else None,
+ )
+ else:
+ logger.error(f"Job {job_id} failed with return code {returncode}")
+ self.db.update_job_status(
+ job_id,
+ JobStatus.FAILED,
+ error_message=f"Training process exited with code {returncode}",
+ logs=final_logs,
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing job {job_id}: {e}", exc_info=True)
+ current_job = self.db.get_job(job_id)
+ if current_job and current_job["status"] != JobStatus.CANCELLED:
+ self.db.update_job_status(job_id, JobStatus.FAILED, error_message=str(e))
+
+ finally:
+ self.current_process = None
+ self.current_job_id = None
+
+ def _build_training_command(self, job: dict) -> list[str]:
+ """Build the training command from job parameters.
+
+ Args:
+ job: Job dictionary
+
+ Returns:
+ Command as list of strings
+ """
+ params = job["params"]
+
+ python_exe = sys.executable
+
+ scripts_dir = Path(__file__).parent.parent
+ train_script = scripts_dir / "train_cli.py"
+
+ cmd = [
+ python_exe,
+ str(train_script),
+ "--job-id",
+ str(job["id"]),
+ "--dataset",
+ job["dataset_name"],
+ "--model-version",
+ params["model_source"],
+ "--learning-rate",
+ str(params["learning_rate"]),
+ "--steps",
+ str(params["steps"]),
+ "--lora-rank",
+ str(params["lora_rank"]),
+ "--batch-size",
+ str(params["batch_size"]),
+ "--width",
+ str(params["width"]),
+ "--height",
+ str(params["height"]),
+ "--num-frames",
+ str(params["num_frames"]),
+ "--validation-prompt",
+ params["validation_prompt"],
+ "--validation-interval",
+ str(params["validation_interval"]),
+ ]
+
+ if params.get("id_token"):
+ cmd.extend(["--id-token", params["id_token"]])
+
+ if params.get("push_to_hub"):
+ cmd.append("--push-to-hub")
+ if params.get("hf_model_id"):
+ cmd.extend(["--hf-model-id", params["hf_model_id"]])
+ if params.get("hf_token"):
+ os.environ["HF_TOKEN"] = params["hf_token"]
+
+ return cmd
+
+ def _get_output_dir(self, job: dict) -> Optional[Path]:
+ """Get the output directory for a job.
+
+ Args:
+ job: Job dictionary
+
+ Returns:
+ Output directory path or None
+ """
+ params = job["params"]
+ project_root = Path(__file__).parent.parent.parent
+ output_dir = project_root / "outputs" / f"lora_r{params['lora_rank']}_job{job['id']}"
+
+ if output_dir.exists():
+ return output_dir
+ return None
+
+ def _find_latest_checkpoint(self, output_dir: Optional[Path]) -> Optional[Path]:
+ """Find the latest checkpoint in the output directory.
+
+ Args:
+ output_dir: Output directory path
+
+ Returns:
+ Path to latest checkpoint or None
+ """
+ if not output_dir or not output_dir.exists():
+ return None
+
+ checkpoints_dir = output_dir / "checkpoints"
+ if not checkpoints_dir.exists():
+ return None
+
+ checkpoints = list(checkpoints_dir.glob("*.safetensors"))
+ if not checkpoints:
+ return None
+
+ return max(checkpoints, key=lambda p: p.stat().st_mtime)
+
+
+def main() -> None:
+ """Main entry point for the queue worker."""
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+ )
+
+ project_root = Path(__file__).parent.parent.parent
+ db_path = project_root / "jobs.db"
+
+ worker = QueueWorker(db_path)
+
+ try:
+ worker.start()
+ except KeyboardInterrupt:
+ logger.info("Received keyboard interrupt")
+ worker.stop()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/start_gradio_with_worker.py b/scripts/start_gradio_with_worker.py
new file mode 100755
index 0000000..7321a26
--- /dev/null
+++ b/scripts/start_gradio_with_worker.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+"""
+Start both the queue worker and Gradio UI together.
+
+This script launches:
+1. Queue worker in the background to process training jobs
+2. Gradio UI for user interaction
+
+Both processes will be managed together, and stopping this script will stop both.
+"""
+
+import logging
+import signal
+import subprocess
+import sys
+import time
+from pathlib import Path
+
+# Configure simple logging for terminal output
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(message)s", # Simple format for terminal output
+ handlers=[logging.StreamHandler(sys.stdout)],
+)
+logger = logging.getLogger(__name__)
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+class ServiceManager:
+ """Manage worker and UI processes."""
+
+ def __init__(self):
+ self.worker_process = None
+ self.ui_process = None
+ self.running = True
+
+ signal.signal(signal.SIGINT, self._handle_shutdown)
+ signal.signal(signal.SIGTERM, self._handle_shutdown)
+
+ def _handle_shutdown(self, _signum: int, _frame: object) -> None:
+ """Handle shutdown signals."""
+ self.stop()
+ sys.exit(0)
+
+ def start(self) -> bool:
+ """Start both worker and UI processes."""
+ python_exe = sys.executable
+ scripts_dir = Path(__file__).parent
+
+ logger.info("🚀 Starting LTX-Video Trainer with Queue System")
+ logger.info("=" * 60)
+ logger.info("\n📋 Starting job worker...")
+
+ worker_script = scripts_dir / "jobs" / "run_worker.py"
+ self.worker_process = subprocess.Popen(
+ [python_exe, str(worker_script)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ )
+
+ time.sleep(2)
+
+ if self.worker_process.poll() is not None:
+ logger.error("❌ Worker failed to start")
+ if self.worker_process.stdout:
+ logger.error("\n📋 Worker Error Output:")
+ logger.error("-" * 60)
+ for line in self.worker_process.stdout:
+ logger.error(line.rstrip())
+ logger.error("-" * 60)
+ return False
+
+ logger.info("✅ Queue worker started")
+ logger.info("\n🎨 Starting Gradio UI...")
+
+ ui_script = scripts_dir / "app_gradio_v2.py"
+ self.ui_process = subprocess.Popen(
+ [python_exe, str(ui_script)],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ bufsize=1,
+ )
+
+ time.sleep(3)
+
+ if self.ui_process.poll() is not None:
+ logger.error("❌ UI failed to start")
+ if self.ui_process.stdout:
+ logger.error("\n📋 UI Error Output:")
+ logger.error("-" * 60)
+ for line in self.ui_process.stdout:
+ logger.error(line.rstrip())
+ logger.error("-" * 60)
+ self.stop()
+ return False
+
+ logger.info("✅ Gradio UI started")
+ logger.info("\n" + "=" * 60)
+ logger.info("🎉 All services running!")
+ logger.info("📍 Open your browser to: http://localhost:7860")
+ logger.info("💡 Press Ctrl+C to stop all services")
+ logger.info("=" * 60 + "\n")
+
+ return True
+
+ def monitor(self) -> None:
+ """Monitor both processes and restart if needed."""
+ while self.running:
+ if self.worker_process and self.worker_process.poll() is not None:
+ logger.error("⚠️ Queue worker stopped unexpectedly")
+ if self.worker_process.stdout:
+ remaining = self.worker_process.stdout.read()
+ if remaining:
+ logger.error("\n📋 Worker Error Output:")
+ logger.error("-" * 60)
+ logger.error(remaining)
+ logger.error("-" * 60)
+ self.running = False
+ break
+
+ if self.ui_process and self.ui_process.poll() is not None:
+ logger.error("⚠️ UI stopped unexpectedly")
+ if self.ui_process.stdout:
+ remaining = self.ui_process.stdout.read()
+ if remaining:
+ logger.error("\n📋 UI Error Output:")
+ logger.error("-" * 60)
+ logger.error(remaining)
+ logger.error("-" * 60)
+ self.running = False
+ break
+
+            # Drain child stdout so the pipe buffers don't fill up and stall the
+            # children; the output itself is intentionally discarded here.
+            if self.ui_process and self.ui_process.stdout:
+                self.ui_process.stdout.readline()
+
+            if self.worker_process and self.worker_process.stdout:
+                self.worker_process.stdout.readline()
+
+ time.sleep(0.1)
+
+ def stop(self) -> None:
+ """Stop both processes."""
+ logger.info("\n🛑 Shutting down services...")
+ self.running = False
+
+ if self.ui_process:
+ self.ui_process.terminate()
+ try:
+ self.ui_process.wait(timeout=5)
+ except subprocess.TimeoutExpired:
+ self.ui_process.kill()
+
+ if self.worker_process:
+ self.worker_process.terminate()
+ try:
+ self.worker_process.wait(timeout=10)
+ except subprocess.TimeoutExpired:
+ self.worker_process.kill()
+
+
+def main() -> None:
+ """Main entry point."""
+ manager = ServiceManager()
+
+ if not manager.start():
+ sys.exit(1)
+
+ try:
+ manager.monitor()
+ except KeyboardInterrupt:
+ pass
+ finally:
+ manager.stop()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/train_cli.py b/scripts/train_cli.py
new file mode 100755
index 0000000..82a083a
--- /dev/null
+++ b/scripts/train_cli.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""
+Command-line training script that integrates with the job queue system.
+
+This script can be called directly with parameters or via the queue worker.
+"""
+
+import argparse
+import os
+import shutil
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import yaml
+from huggingface_hub import login
+
+from ltxv_trainer.config import LtxvTrainerConfig
+from ltxv_trainer.trainer import LtxvTrainer
+from scripts.jobs.database import JobDatabase, JobStatus
+
+
+def generate_config(args: argparse.Namespace, output_dir: Path) -> dict:
+ """Generate training configuration from command-line arguments.
+
+ Args:
+ args: Parsed command-line arguments
+ output_dir: Output directory for training
+
+ Returns:
+ Configuration dictionary
+ """
+ config = {
+ "model": {
+ "model_source": args.model_version,
+ "training_mode": "lora",
+ "load_checkpoint": None,
+ },
+ "lora": {
+ "rank": args.lora_rank,
+ "alpha": args.lora_rank, # Usually alpha = rank
+ "dropout": 0.0,
+ "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
+ },
+ "conditioning": {
+ "mode": "none",
+ "first_frame_conditioning_p": 0.1,
+ },
+ "optimization": {
+ "learning_rate": args.learning_rate,
+ "steps": args.steps,
+ "batch_size": args.batch_size,
+ "gradient_accumulation_steps": 1,
+ "max_grad_norm": 1.0,
+ "optimizer_type": "adamw",
+ "scheduler_type": "linear",
+ "scheduler_params": {},
+ "enable_gradient_checkpointing": False,
+ },
+ "acceleration": {
+ "mixed_precision_mode": "bf16",
+ "quantization": None,
+ "load_text_encoder_in_8bit": True,
+ "compile_with_inductor": False,
+ "compilation_mode": "reduce-overhead",
+ },
+ "data": {
+ "preprocessed_data_root": str(args.dataset_dir),
+ "num_dataloader_workers": 2,
+ },
+ "validation": {
+ "prompts": [args.validation_prompt],
+ "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
+ "images": None,
+ "reference_videos": None,
+ "video_dims": [args.width, args.height, args.num_frames],
+ "seed": 42,
+ "inference_steps": 30,
+ "interval": args.validation_interval,
+ "videos_per_prompt": 1,
+ "guidance_scale": 3.5,
+ "skip_initial_validation": False,
+ },
+ "checkpoints": {
+ "interval": None,
+ "keep_last_n": 1,
+ },
+ "hub": {
+ "push_to_hub": args.push_to_hub,
+ "hub_model_id": args.hf_model_id if args.push_to_hub else None,
+ },
+ "flow_matching": {
+ "timestep_sampling_mode": "shifted_logit_normal",
+ "timestep_sampling_params": {},
+ },
+ "wandb": {
+ "enabled": False,
+ "project": "ltxv-trainer",
+ "entity": None,
+ "tags": [],
+ "log_validation_videos": True,
+ },
+ "seed": 42,
+ "output_dir": str(output_dir),
+ }
+
+ return config
+
+
+def setup_dataset(dataset_name: str, datasets_root: Path, training_data_dir: Path) -> None:
+ """Set up the training dataset by copying from managed datasets.
+
+ Args:
+ dataset_name: Name of the managed dataset
+ datasets_root: Root directory of managed datasets
+ training_data_dir: Temporary training data directory
+ """
+ managed_dataset_dir = datasets_root / dataset_name
+ managed_dataset_json = managed_dataset_dir / "dataset.json"
+
+ if not managed_dataset_json.exists():
+ raise ValueError(f"Dataset '{dataset_name}' not found")
+
+ if training_data_dir.exists():
+ shutil.rmtree(training_data_dir)
+ training_data_dir.mkdir(parents=True)
+
+ shutil.copy2(managed_dataset_json, training_data_dir / "dataset.json")
+
+ precomputed_dir = managed_dataset_dir / ".precomputed"
+ if precomputed_dir.exists():
+ shutil.copytree(precomputed_dir, training_data_dir / ".precomputed")
+
+
+def train_with_progress_callback(trainer: LtxvTrainer, db: JobDatabase, job_id: int) -> None:
+ """Train with progress updates to database.
+
+ Args:
+ trainer: Trainer instance
+ db: Job database
+ job_id: Current job ID
+ """
+
+ def progress_callback(step: int, total_steps: int, sampled_videos: list[Path] | None = None) -> None:
+ """Update job progress in database."""
+ progress_pct = (step / total_steps) * 100
+ progress_msg = f"Step {step}/{total_steps} ({progress_pct:.1f}%)"
+
+ validation_sample = None
+ if sampled_videos:
+ validation_sample = str(sampled_videos[0])
+
+ db.update_job_status(job_id, JobStatus.RUNNING, progress=progress_msg, validation_sample=validation_sample)
+
+ trainer.train(step_callback=progress_callback)
+
+
+def main() -> int | None:
+ """Main entry point for CLI training."""
+ parser = argparse.ArgumentParser(description="Train LTX-Video LoRA model")
+
+ parser.add_argument("--job-id", type=int, help="Job ID for queue tracking")
+
+ parser.add_argument("--dataset", required=True, help="Dataset name")
+
+ parser.add_argument("--model-version", default="0.9.1", help="Model version")
+
+ parser.add_argument("--learning-rate", type=float, default=2e-4, help="Learning rate")
+ parser.add_argument("--steps", type=int, default=1500, help="Training steps")
+ parser.add_argument("--lora-rank", type=int, default=128, help="LoRA rank")
+ parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
+
+ parser.add_argument("--width", type=int, default=768, help="Video width")
+ parser.add_argument("--height", type=int, default=768, help="Video height")
+ parser.add_argument("--num-frames", type=int, default=25, help="Number of frames")
+ parser.add_argument("--id-token", type=str, default="", help="LoRA ID token")
+
+ parser.add_argument(
+ "--validation-prompt",
+ default="a professional portrait video of a person",
+ help="Validation prompt",
+ )
+ parser.add_argument("--validation-interval", type=int, default=100, help="Validation interval")
+
+ parser.add_argument("--push-to-hub", action="store_true", help="Push to HuggingFace Hub")
+ parser.add_argument("--hf-model-id", type=str, help="HuggingFace model ID")
+
+ args = parser.parse_args()
+
+ project_root = Path(__file__).parent.parent
+ datasets_root = project_root / "datasets"
+ training_data_dir = (
+ project_root / "training_data" / f"job_{args.job_id}" if args.job_id else project_root / "training_data"
+ )
+ output_dir = (
+ project_root / "outputs" / f"lora_r{args.lora_rank}_job{args.job_id}"
+ if args.job_id
+ else project_root / "outputs" / f"lora_r{args.lora_rank}"
+ )
+
+ db = None
+ if args.job_id:
+ db_path = project_root / "jobs.db"
+ db = JobDatabase(db_path)
+
+ try:
+ if "HF_TOKEN" in os.environ:
+ login(token=os.environ["HF_TOKEN"])
+
+ args.dataset_dir = training_data_dir
+ setup_dataset(args.dataset, datasets_root, training_data_dir)
+
+ config = generate_config(args, output_dir)
+
+ output_dir.mkdir(parents=True, exist_ok=True)
+ config_path = output_dir / "training_config.yaml"
+ with open(config_path, "w") as f:
+ yaml.dump(config, f, indent=2)
+
+ trainer_config = LtxvTrainerConfig(**config)
+ trainer = LtxvTrainer(trainer_config)
+
+ if db and args.job_id:
+ train_with_progress_callback(trainer, db, args.job_id)
+ else:
+ trainer.train()
+
+ return 0
+
+ except Exception as e:
+ if db and args.job_id:
+ db.update_job_status(args.job_id, JobStatus.FAILED, error_message=str(e))
+ raise
+
+ finally:
+ if db:
+ db.close()
+
+
+if __name__ == "__main__":
+ sys.exit(main())