diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 2986fc139..9a2a74f51 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -122,17 +122,17 @@ jobs: # Actual tests encoder-test: - 'fastvideo/v1/models/encoders/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/encoders/**' - *common-paths vae-test: - 'fastvideo/v1/models/vaes/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/vaes/**' - *common-paths transformer-test: - 'fastvideo/v1/models/dits/**' - - 'fastvideo/v1/models/loaders/**' + - 'fastvideo/v1/models/loader/**' - 'fastvideo/v1/tests/transformers/**' - 'fastvideo/v1/layers/**' - 'fastvideo/v1/attention/**' diff --git a/examples/inference/basic/basic.py b/examples/inference/basic/basic.py index 4161004f1..c97f59e56 100644 --- a/examples/inference/basic/basic.py +++ b/examples/inference/basic/basic.py @@ -10,7 +10,7 @@ def main(): # attempt to identify the optimal arguments. generator = VideoGenerator.from_pretrained( "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", - # if num_gpus > 1, FastVideo will automatically handle distributed setup + # FastVideo will automatically handle distributed setup num_gpus=2, use_fsdp_inference=True, use_cpu_offload=False diff --git a/fastvideo/v1/configs/models/base.py b/fastvideo/v1/configs/models/base.py index 84b0de57c..40eb9ad66 100644 --- a/fastvideo/v1/configs/models/base.py +++ b/fastvideo/v1/configs/models/base.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field, fields -from typing import Any, Dict +from typing import Any, Dict, List, Tuple from fastvideo.v1.logger import init_logger @@ -12,7 +12,9 @@ # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users @dataclass class ArchConfig: - pass + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names @dataclass diff --git a/fastvideo/v1/configs/models/dits/stepvideo.py b/fastvideo/v1/configs/models/dits/stepvideo.py index abad243ec..78fc6b0b3 100644 --- a/fastvideo/v1/configs/models/dits/stepvideo.py +++ b/fastvideo/v1/configs/models/dits/stepvideo.py @@ -5,13 +5,11 @@ from fastvideo.v1.configs.models.dits.base import DiTArchConfig, DiTConfig -def is_blocks(n: str, m) -> bool: - return "blocks" in n and str.isdigit(n.split(".")[-1]) - - @dataclass class StepVideoArchConfig(DiTArchConfig): - _fsdp_shard_conditions: list = field(default_factory=lambda: [is_blocks]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit()]) _param_names_mapping: dict = field( default_factory=lambda: { diff --git a/fastvideo/v1/configs/models/encoders/base.py b/fastvideo/v1/configs/models/encoders/base.py index febbd23f2..d2e686add 100644 --- a/fastvideo/v1/configs/models/encoders/base.py +++ b/fastvideo/v1/configs/models/encoders/base.py @@ -32,8 +32,11 @@ class TextEncoderArchConfig(EncoderArchConfig): output_past: bool = True scalable_attention: bool = True tie_word_embeddings: bool = False - + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=list + ) # mapping from huggingface weight names to custom names tokenizer_kwargs: Dict[str, Any] = field(default_factory=dict) + _fsdp_shard_conditions: list = field(default_factory=lambda: []) def __post_init__(self) -> None: self.tokenizer_kwargs = { 
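Illustrative sketch (not part of the patch): the stacked_params_mapping field introduced on ArchConfig/TextEncoderArchConfig above is the (param_name, shard_name, shard_id) table that each encoder's load_weights consumes later in this patch. Below is a minimal sketch, under stated assumptions, of how such a mapping is typically applied when per-projection HuggingFace weights (q_proj/k_proj/v_proj) are folded into a fused parameter; the params_dict and weight_loader names are assumptions for illustration only, not guaranteed FastVideo API.

def load_stacked_weight(params_dict, name, loaded_weight, stacked_params_mapping):
    # Each entry maps a fused parameter name (e.g. "qkv_proj") to one of the
    # per-projection checkpoint names (e.g. "q_proj") plus a shard id.
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        fused_name = name.replace(weight_name, param_name)
        param = params_dict[fused_name]
        # The fused parameter's weight_loader is assumed to copy this shard
        # ("q", "k", "v", 0, 1, ...) into the right slice of the stacked tensor.
        param.weight_loader(param, loaded_weight, shard_id)
        return fused_name
    return None  # not a stacked weight; the caller loads it under its own name

Keeping this table in the arch config, rather than hard-coding it inside each model's load_weights as before, lets the loaders and the weight-comparison tests read the same mapping.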
diff --git a/fastvideo/v1/configs/models/encoders/clip.py b/fastvideo/v1/configs/models/encoders/clip.py index 6e81d41e2..ab9340c02 100644 --- a/fastvideo/v1/configs/models/encoders/clip.py +++ b/fastvideo/v1/configs/models/encoders/clip.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (ImageEncoderArchConfig, ImageEncoderConfig, @@ -8,6 +8,14 @@ TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embeddings") + + @dataclass class CLIPTextArchConfig(TextEncoderArchConfig): vocab_size: int = 49408 @@ -27,6 +35,15 @@ class CLIPTextArchConfig(TextEncoderArchConfig): bos_token_id: int = 49406 eos_token_id: int = 49407 text_len: int = 77 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_transformer_layer, _is_embeddings]) @dataclass @@ -45,6 +62,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig): attention_dropout: float = 0.0 initializer_range: float = 0.02 initializer_factor: float = 1.0 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ]) @dataclass diff --git a/fastvideo/v1/configs/models/encoders/llama.py b/fastvideo/v1/configs/models/encoders/llama.py index 1fde6e185..0901e98ae 100644 --- a/fastvideo/v1/configs/models/encoders/llama.py +++ b/fastvideo/v1/configs/models/encoders/llama.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig, TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + @dataclass class LlamaArchConfig(TextEncoderArchConfig): vocab_size: int = 32000 @@ -32,6 +44,18 @@ class LlamaArchConfig(TextEncoderArchConfig): head_dim: Optional[int] = None hidden_state_skip_layer: int = 2 text_len: int = 256 + stacked_params_mapping: List[Tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [_is_transformer_layer, _is_embeddings, _is_final_norm]) @dataclass diff --git a/fastvideo/v1/configs/models/encoders/t5.py b/fastvideo/v1/configs/models/encoders/t5.py index 7ec4d4a1b..79e9c9ad0 100644 --- a/fastvideo/v1/configs/models/encoders/t5.py +++ b/fastvideo/v1/configs/models/encoders/t5.py @@ -1,11 +1,23 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import dataclass, field -from typing import Optional +from typing import 
List, Optional, Tuple from fastvideo.v1.configs.models.encoders.base import (TextEncoderArchConfig, TextEncoderConfig) +def _is_transformer_layer(n: str, m) -> bool: + return "block" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("shared") + + +def _is_final_layernorm(n: str, m) -> bool: + return n.endswith("final_layer_norm") + + @dataclass class T5ArchConfig(TextEncoderArchConfig): vocab_size: int = 32128 @@ -29,6 +41,16 @@ class T5ArchConfig(TextEncoderArchConfig): eos_token_id: int = 1 classifier_dropout: float = 0.0 text_len: int = 512 + stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q", "q"), + (".qkv_proj", ".k", "k"), + (".qkv_proj", ".v", "v"), + ]) + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [_is_transformer_layer, _is_embeddings, _is_final_layernorm]) # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py def __post_init__(self): diff --git a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py index 57cf092ca..f808d2f09 100644 --- a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py +++ b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_iterable_style.py @@ -11,7 +11,7 @@ build_parquet_iterable_style_dataloader) from fastvideo.v1.distributed import get_world_rank from fastvideo.v1.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_torch_device, + cleanup_dist_env_and_memory, get_local_torch_device, maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.logger import init_logger @@ -148,8 +148,8 @@ def main() -> None: break # Move data to device - latents = latents.to(get_torch_device()) - embeddings = embeddings.to(get_torch_device()) + latents = latents.to(get_local_torch_device()) + embeddings = embeddings.to(get_local_torch_device()) # Calculate actual batch size batch_size = latents.size(0) diff --git a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py index a2614edda..7618471ea 100644 --- a/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py +++ b/fastvideo/v1/dataset/benchmarks/benchmark_parquet_dataset_map_style.py @@ -13,7 +13,7 @@ build_parquet_map_style_dataloader) from fastvideo.v1.distributed import get_world_rank from fastvideo.v1.distributed.parallel_state import ( - cleanup_dist_env_and_memory, get_torch_device, + cleanup_dist_env_and_memory, get_local_torch_device, maybe_init_distributed_environment_and_model_parallel) from fastvideo.v1.logger import init_logger @@ -165,8 +165,8 @@ def main() -> None: break # Move data to device - latents = latents.to(get_torch_device()) - embeddings = embeddings.to(get_torch_device()) + latents = latents.to(get_local_torch_device()) + embeddings = embeddings.to(get_local_torch_device()) # Calculate actual batch size batch_size = latents.size(0) diff --git a/fastvideo/v1/distributed/__init__.py b/fastvideo/v1/distributed/__init__.py index 5c0a1af6e..7e96bafa1 100644 --- a/fastvideo/v1/distributed/__init__.py +++ b/fastvideo/v1/distributed/__init__.py @@ -3,10 +3,10 @@ from fastvideo.v1.distributed.communication_op import * from fastvideo.v1.distributed.parallel_state import ( cleanup_dist_env_and_memory, 
get_dp_group, get_dp_rank, get_dp_world_size, - get_sp_group, get_sp_parallel_rank, get_sp_world_size, get_torch_device, - get_tp_group, get_tp_rank, get_tp_world_size, get_world_group, - get_world_rank, get_world_size, init_distributed_environment, - initialize_model_parallel, + get_local_torch_device, get_sp_group, get_sp_parallel_rank, + get_sp_world_size, get_tp_group, get_tp_rank, get_tp_world_size, + get_world_group, get_world_rank, get_world_size, + init_distributed_environment, initialize_model_parallel, maybe_init_distributed_environment_and_model_parallel, model_parallel_is_initialized) from fastvideo.v1.distributed.utils import * @@ -40,5 +40,5 @@ "get_tp_world_size", # Get torch device - "get_torch_device", + "get_local_torch_device", ] diff --git a/fastvideo/v1/distributed/parallel_state.py b/fastvideo/v1/distributed/parallel_state.py index b15a9f6c0..8c992f2a5 100644 --- a/fastvideo/v1/distributed/parallel_state.py +++ b/fastvideo/v1/distributed/parallel_state.py @@ -904,7 +904,7 @@ def get_dp_rank() -> int: return get_dp_group().rank_in_group -def get_torch_device() -> torch.device: +def get_local_torch_device() -> torch.device: """Return the torch device for the current rank.""" return torch.device(f"cuda:{envs.LOCAL_RANK}") @@ -1232,4 +1232,4 @@ def initialize_sequence_parallel_group( backend, group_name=group_name) - return sp_group + return sp_group \ No newline at end of file diff --git a/fastvideo/v1/fastvideo_args.py b/fastvideo/v1/fastvideo_args.py index cc4a30fad..616466e70 100644 --- a/fastvideo/v1/fastvideo_args.py +++ b/fastvideo/v1/fastvideo_args.py @@ -58,8 +58,10 @@ class FastVideoArgs: output_type: str = "pil" - use_cpu_offload: bool = True + use_cpu_offload: bool = True # For DiT use_fsdp_inference: bool = True + text_encoder_offload: bool = True + pin_cpu_memory: bool = True # STA (Sliding Tile Attention) parameters mask_strategy_file_path: Optional[str] = None @@ -208,7 +210,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--use-cpu-offload", action=StoreBoolean, help= - "Use CPU offload for model inference. Enable if run out of memory with FSDP.", + "Use CPU offload for DiT inference. Enable if run out of memory with FSDP.", ) parser.add_argument( "--use-fsdp-inference", @@ -216,7 +218,19 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help= "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", ) - + parser.add_argument( + "--text-encoder-cpu-offload", + action=StoreBoolean, + help= + "Use CPU offload for text encoder. Enable if run out of memory.", + ) + parser.add_argument( + "--pin-cpu-memory", + action=StoreBoolean, + help= + "Pin memory for CPU offload. Only added as a temp workaround if it throws \"CUDA error: invalid argument\". " + "Should be enabled in almost all cases", + ) parser.add_argument( "--disable-autocast", action=StoreBoolean, diff --git a/fastvideo/v1/models/dits/stepvideo.py b/fastvideo/v1/models/dits/stepvideo.py index c70f1c090..d0ad9854a 100644 --- a/fastvideo/v1/models/dits/stepvideo.py +++ b/fastvideo/v1/models/dits/stepvideo.py @@ -455,10 +455,7 @@ def forward(self, class StepVideoModel(BaseDiT): # (Optional) Keep the same attribute for compatibility with splitting, etc. - _fsdp_shard_conditions = [ - lambda n, m: "transformer_blocks" in n and n.split(".")[-1].isdigit(), - # lambda n, m: "pos_embed" in n # If needed for the patch embedding. 
- ] + _fsdp_shard_conditions = StepVideoConfig()._fsdp_shard_conditions _param_names_mapping = StepVideoConfig()._param_names_mapping _reverse_param_names_mapping = StepVideoConfig( )._reverse_param_names_mapping diff --git a/fastvideo/v1/models/dits/wanvideo.py b/fastvideo/v1/models/dits/wanvideo.py index 62d542c18..039f57aba 100644 --- a/fastvideo/v1/models/dits/wanvideo.py +++ b/fastvideo/v1/models/dits/wanvideo.py @@ -318,9 +318,9 @@ def forward( value, _ = self.to_v(norm_hidden_states) if self.norm_q is not None: - query = self.norm_q.forward_native(query) + query = self.norm_q(query) if self.norm_k is not None: - key = self.norm_k.forward_native(key) + key = self.norm_k(key) query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) @@ -465,9 +465,9 @@ def forward( gate_compress, _ = self.to_gate_compress(norm_hidden_states) if self.norm_q is not None: - query = self.norm_q.forward_native(query) + query = self.norm_q(query) if self.norm_k is not None: - key = self.norm_k.forward_native(key) + key = self.norm_k(key) query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1)) diff --git a/fastvideo/v1/models/encoders/base.py b/fastvideo/v1/models/encoders/base.py index 4c7c45ec2..69b3a4846 100644 --- a/fastvideo/v1/models/encoders/base.py +++ b/fastvideo/v1/models/encoders/base.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod -from typing import Optional, Tuple +from dataclasses import field +from typing import List, Optional, Tuple import torch from torch import nn @@ -12,6 +13,9 @@ class TextEncoder(nn.Module, ABC): + _fsdp_shard_conditions: list = field(default_factory=lambda: []) + _stacked_params_mapping: List[Tuple[str, str, + str]] = field(default_factory=list) _supported_attention_backends: Tuple[ AttentionBackendEnum, ...] 
= TextEncoderConfig()._supported_attention_backends @@ -19,6 +23,8 @@ class TextEncoder(nn.Module, ABC): def __init__(self, config: TextEncoderConfig) -> None: super().__init__() self.config = config + self._fsdp_shard_conditions = config._fsdp_shard_conditions + self._stacked_params_mapping = config.arch_config.stacked_params_mapping if not self.supported_attention_backends: raise ValueError( f"Subclass {self.__class__.__name__} must define _supported_attention_backends" diff --git a/fastvideo/v1/models/encoders/clip.py b/fastvideo/v1/models/encoders/clip.py index ecbaba58d..8278e0f71 100644 --- a/fastvideo/v1/models/encoders/clip.py +++ b/fastvideo/v1/models/encoders/clip.py @@ -596,12 +596,7 @@ def device(self): # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ] + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() layer_count = len(self.vision_model.encoder.layers) @@ -620,7 +615,8 @@ def load_weights(self, weights: Iterable[Tuple[str, if layer_idx >= layer_count: continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for (param_name, weight_name, + shard_id) in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/encoders/llama.py b/fastvideo/v1/models/encoders/llama.py index ebf009bf1..2fa32780d 100644 --- a/fastvideo/v1/models/encoders/llama.py +++ b/fastvideo/v1/models/encoders/llama.py @@ -369,14 +369,7 @@ def forward( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -406,7 +399,7 @@ def load_weights(self, weights: Iterable[Tuple[str, continue else: name = kv_scale_name - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/encoders/t5.py b/fastvideo/v1/models/encoders/t5.py index a4ea46c40..12cb9bd40 100644 --- a/fastvideo/v1/models/encoders/t5.py +++ b/fastvideo/v1/models/encoders/t5.py @@ -124,7 +124,7 @@ def __init__(self, self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon) def forward(self, hidden_states) -> torch.Tensor: - forwarded_states = self.layer_norm.forward_native(hidden_states) + forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) hidden_states = hidden_states + forwarded_states return hidden_states @@ -362,7 +362,7 @@ def forward( attention_mask: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: - normed_hidden_states = self.layer_norm.forward_native(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( hidden_states=normed_hidden_states, attention_mask=attention_mask, @@ -391,7 
+391,7 @@ def forward( hidden_states: torch.Tensor, attn_metadata: Optional[AttentionMetadata] = None, ) -> torch.Tensor: - normed_hidden_states = self.layer_norm.forward_native(hidden_states) + normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( hidden_states=normed_hidden_states, attn_metadata=attn_metadata, @@ -494,7 +494,7 @@ def forward( attention_mask=attention_mask, attn_metadata=attn_metadata, ) - hidden_states = self.final_layer_norm.forward_native(hidden_states) + hidden_states = self.final_layer_norm(hidden_states) return hidden_states @@ -631,19 +631,13 @@ def forward( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q", "q"), - (".qkv_proj", ".k", "k"), - (".qkv_proj", ".v", "v"), - ] params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: loaded = False if "decoder" in name or "lm_head" in name: continue - for param_name, weight_name, shard_id in stacked_params_mapping: + for param_name, weight_name, shard_id in self.config.arch_config.stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/fastvideo/v1/models/loader/component_loader.py b/fastvideo/v1/models/loader/component_loader.py index 270bc8387..16bbf1efc 100644 --- a/fastvideo/v1/models/loader/component_loader.py +++ b/fastvideo/v1/models/loader/component_loader.py @@ -7,20 +7,23 @@ import time from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Generator, Iterable, List, Optional, Tuple, cast +from typing import Generator, Iterable, List, Optional, Tuple, cast import torch +import torch.distributed as dist import torch.nn as nn from safetensors.torch import load_file as safetensors_load_file +from torch.distributed import init_device_mesh from transformers import AutoImageProcessor, AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from fastvideo.v1.configs.models import EncoderConfig -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config -from fastvideo.v1.models.loader.fsdp_load import maybe_load_fsdp_model +from fastvideo.v1.models.loader.fsdp_load import (maybe_load_fsdp_model, + shard_model) from fastvideo.v1.models.loader.utils import set_default_torch_dtype from fastvideo.v1.models.loader.weight_utils import ( filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, @@ -38,14 +41,12 @@ def __init__(self, device=None) -> None: self.device = device @abstractmethod - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): + def load(self, model_path: str, fastvideo_args: FastVideoArgs): """ Load the component based on the model path, architecture, and inference args. 
Args: model_path: Path to the component model - architecture: Architecture of the component model fastvideo_args: FastVideoArgs Returns: @@ -163,16 +164,18 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" - ) -> Generator[Tuple[str, torch.Tensor], None, None]: + self, source: "Source", + to_cpu: bool) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.fall_back_to_pt, source.allow_patterns_overrides) if use_safetensors: - weights_iterator = safetensors_weights_iterator(hf_weights_files) + weights_iterator = safetensors_weights_iterator(hf_weights_files, + to_cpu=to_cpu) else: - weights_iterator = pt_weights_iterator(hf_weights_files) + weights_iterator = pt_weights_iterator(hf_weights_files, + to_cpu=to_cpu) if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() @@ -182,9 +185,9 @@ def _get_weights_iterator( def _get_all_weights( self, - model_config: Any, model: nn.Module, model_path: str, + to_cpu: bool, ) -> Generator[Tuple[str, torch.Tensor], None, None]: primary_weights = TextEncoderLoader.Source( model_path, @@ -193,18 +196,17 @@ def _get_all_weights( allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights) + yield from self._get_weights_iterator(primary_weights, to_cpu) secondary_weights = cast( Iterable[TextEncoderLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, to_cpu) - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the text encoders based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the text encoders based on the model path, and inference args.""" # model_config: PretrainedConfig = get_hf_config( # model=model_path, # trust_remote_code=fastvideo_args.trust_remote_code, @@ -233,16 +235,23 @@ def load(self, model_path: str, architecture: str, encoder_precision = fastvideo_args.pipeline_config.text_encoder_precisions[ 1] - target_device = get_torch_device() + target_device = get_local_torch_device() # TODO(will): add support for other dtypes return self.load_model(model_path, encoder_config, target_device, - encoder_precision) + fastvideo_args, encoder_precision) def load_model(self, model_path: str, model_config: EncoderConfig, target_device: torch.device, + fastvideo_args: FastVideoArgs, dtype: str = "fp16"): + use_cpu_offload = fastvideo_args.text_encoder_offload and len( + getattr(model_config, "_fsdp_shard_conditions", [])) > 0 + + if fastvideo_args.text_encoder_offload: + target_device = torch.device("cpu") + with set_default_torch_dtype(PRECISION_TO_TYPE[dtype]): with target_device: architectures = getattr(model_config, "architectures", []) @@ -251,12 +260,26 @@ def load_model(self, weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( - self._get_all_weights(model_config, model, model_path)) + self._get_all_weights(model, model_path, + to_cpu=use_cpu_offload)) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", 
self.counter_after_loading_weights - self.counter_before_loading_weights) + + if use_cpu_offload: + mesh = init_device_mesh( + "cuda", + mesh_shape=(1, dist.get_world_size()), + mesh_dim_names=("offload", "replicate"), + ) + shard_model(model, + cpu_offload=True, + reshard_after_forward=True, + mesh=mesh["offload"], + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=fastvideo_args.pin_cpu_memory) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. # if loaded_weights is not None: @@ -270,9 +293,8 @@ def load_model(self, class ImageEncoderLoader(TextEncoderLoader): - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the text encoders based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the text encoders based on the model path, and inference args.""" # model_config: PretrainedConfig = get_hf_config( # model=model_path, # trust_remote_code=fastvideo_args.trust_remote_code, @@ -290,19 +312,18 @@ def load(self, model_path: str, architecture: str, encoder_config = fastvideo_args.pipeline_config.image_encoder_config encoder_config.update_model_arch(model_config) - target_device = get_torch_device() + target_device = get_local_torch_device() # TODO(will): add support for other dtypes return self.load_model( - model_path, encoder_config, target_device, + model_path, encoder_config, target_device, fastvideo_args, fastvideo_args.pipeline_config.image_encoder_precision) class ImageProcessorLoader(ComponentLoader): """Loader for image processor.""" - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the image processor based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the image processor based on the model path, and inference args.""" logger.info("Loading image processor from %s", model_path) image_processor = AutoImageProcessor.from_pretrained(model_path, ) @@ -314,9 +335,8 @@ def load(self, model_path: str, architecture: str, class TokenizerLoader(ComponentLoader): """Loader for tokenizers.""" - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the tokenizer based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the tokenizer based on the model path, and inference args.""" logger.info("Loading tokenizer from %s", model_path) tokenizer = AutoTokenizer.from_pretrained( @@ -333,9 +353,8 @@ def load(self, model_path: str, architecture: str, class VAELoader(ComponentLoader): """Loader for VAE.""" - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the VAE based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the VAE based on the model path, and inference args.""" config = get_diffusers_config(model=model_path) class_name = config.pop("_class_name") assert class_name is not None, "Model config does not contain a _class_name attribute. Only diffusers format is supported." 
@@ -346,7 +365,7 @@ def load(self, model_path: str, architecture: str, with set_default_torch_dtype(PRECISION_TO_TYPE[ fastvideo_args.pipeline_config.vae_precision]): vae_cls, _ = ModelRegistry.resolve_model_cls(class_name) - vae = vae_cls(vae_config).to(get_torch_device()) + vae = vae_cls(vae_config).to(get_local_torch_device()) # Find all safetensors files safetensors_list = glob.glob( @@ -365,9 +384,8 @@ def load(self, model_path: str, architecture: str, class TransformerLoader(ComponentLoader): """Loader for transformer.""" - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the transformer based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the transformer based on the model path, and inference args.""" config = get_diffusers_config(model=model_path) hf_config = deepcopy(config) cls_name = config.pop("_class_name") @@ -405,7 +423,7 @@ def load(self, model_path: str, architecture: str, "hf_config": hf_config }, weight_dir_list=safetensors_list, - device=get_torch_device(), + device=get_local_torch_device(), hsdp_replicate_dim=fastvideo_args.hsdp_replicate_dim, hsdp_shard_dim=fastvideo_args.hsdp_shard_dim, cpu_offload=fastvideo_args.use_cpu_offload, @@ -449,9 +467,8 @@ def load(self, model_path: str, architecture: str, class SchedulerLoader(ComponentLoader): """Loader for scheduler.""" - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load the scheduler based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load the scheduler based on the model path, and inference args.""" config = get_diffusers_config(model=model_path) class_name = config.pop("_class_name") @@ -475,9 +492,8 @@ def __init__(self, library="transformers") -> None: super().__init__() self.library = library - def load(self, model_path: str, architecture: str, - fastvideo_args: FastVideoArgs): - """Load a generic component based on the model path, architecture, and inference args.""" + def load(self, model_path: str, fastvideo_args: FastVideoArgs): + """Load a generic component based on the model path, and inference args.""" logger.warning("Using generic loader for %s with library %s", model_path, self.library) @@ -513,7 +529,7 @@ class PipelineComponentLoader: @staticmethod def load_module(module_name: str, component_model_path: str, - transformers_or_diffusers: str, architecture: str, + transformers_or_diffusers: str, fastvideo_args: FastVideoArgs): """ Load a pipeline module. 
@@ -522,7 +538,6 @@ def load_module(module_name: str, component_model_path: str, module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler") component_model_path: Path to the component model transformers_or_diffusers: Whether the module is from transformers or diffusers - architecture: Architecture of the component model pipeline_args: Inference arguments Returns: @@ -540,4 +555,4 @@ def load_module(module_name: str, component_model_path: str, transformers_or_diffusers) # Load the module - return loader.load(component_model_path, architecture, fastvideo_args) + return loader.load(component_model_path, fastvideo_args) diff --git a/fastvideo/v1/models/loader/fsdp_load.py b/fastvideo/v1/models/loader/fsdp_load.py index a9c890f69..23375baf8 100644 --- a/fastvideo/v1/models/loader/fsdp_load.py +++ b/fastvideo/v1/models/loader/fsdp_load.py @@ -69,6 +69,7 @@ def maybe_load_fsdp_model( fsdp_inference: bool = False, output_dtype: Optional[torch.dtype] = None, training_mode: bool = True, + pin_cpu_memory: bool = True, ) -> torch.nn.Module: """ Load the model with FSDP if is training, else load the model without FSDP. @@ -101,9 +102,12 @@ def maybe_load_fsdp_model( cpu_offload=cpu_offload, reshard_after_forward=True, mp_policy=mp_policy, - mesh=device_mesh) + mesh=device_mesh, + fsdp_shard_conditions=model._fsdp_shard_conditions, + pin_cpu_memory=pin_cpu_memory) - weight_iterator = safetensors_weights_iterator(weight_dir_list) + weight_iterator = safetensors_weights_iterator(weight_dir_list, + to_cpu=cpu_offload) param_names_mapping_fn = get_param_names_mapping(model._param_names_mapping) load_model_from_full_model_state_dict( model, @@ -129,9 +133,10 @@ def shard_model( *, cpu_offload: bool, reshard_after_forward: bool = True, - mp_policy: Optional[MixedPrecisionPolicy] = None, - dp_mesh: Optional[DeviceMesh] = None, + mp_policy: Optional[MixedPrecisionPolicy] = MixedPrecisionPolicy(), # noqa mesh: Optional[DeviceMesh] = None, + fsdp_shard_conditions: List[Callable[[str, nn.Module], bool]] = [], # noqa + pin_cpu_memory: bool = True, ) -> None: """ Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API. @@ -150,19 +155,29 @@ def shard_model( reshard_after_forward (bool): Whether to reshard parameters and buffers after the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy. - dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism. + mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism. Default to None. + fsdp_shard_conditions (List[Callable[[str, nn.Module], bool]]): A list of functions to determine + which modules to shard with FSDP. + pin_cpu_memory (bool): If set to True, FSDP will pin the CPU memory of the offloaded parameters. Raises: ValueError: If no layer modules were sharded, indicating that no shard_condition was triggered. """ + if fsdp_shard_conditions is None or len(fsdp_shard_conditions) == 0: + logger.warning( + "The FSDP shard condition list is empty or None. 
No modules will be sharded in %s", + type(model).__name__) + return + fsdp_kwargs = { "reshard_after_forward": reshard_after_forward, "mesh": mesh, "mp_policy": mp_policy, } if cpu_offload: - fsdp_kwargs["offload_policy"] = CPUOffloadPolicy() + fsdp_kwargs["offload_policy"] = CPUOffloadPolicy( + pin_memory=pin_cpu_memory) # iterating in reverse to start with # lowest-level modules first @@ -172,7 +187,7 @@ def shard_model( for n, m in reversed(list(model.named_modules())): if any([ shard_condition(n, m) - for shard_condition in model._fsdp_shard_conditions + for shard_condition in fsdp_shard_conditions ]): fully_shard(m, **fsdp_kwargs) num_layers_sharded += 1 diff --git a/fastvideo/v1/models/loader/weight_utils.py b/fastvideo/v1/models/loader/weight_utils.py index b939ab5c5..bb9d668e2 100644 --- a/fastvideo/v1/models/loader/weight_utils.py +++ b/fastvideo/v1/models/loader/weight_utils.py @@ -14,6 +14,7 @@ from safetensors.torch import safe_open from tqdm.auto import tqdm +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.logger import init_logger logger = init_logger(__name__) @@ -118,27 +119,31 @@ def filter_files_not_needed_for_inference( def safetensors_weights_iterator( - hf_weights_files: List[str] + hf_weights_files: List[str], + to_cpu: bool = True, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files.""" enable_tqdm = not torch.distributed.is_initialized( ) or torch.distributed.get_rank() == 0 + device = "cpu" if to_cpu else str(get_local_torch_device()) for st_file in tqdm( hf_weights_files, desc="Loading safetensors checkpoint shards", disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: + with safe_open(st_file, framework="pt", device=device) as f: for name in f.keys(): # noqa: SIM118 param = f.get_tensor(name) yield name, param def pt_weights_iterator( - hf_weights_files: List[str] + hf_weights_files: List[str], + to_cpu: bool = True, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model bin/pt files.""" + device = "cpu" if to_cpu else str(get_local_torch_device()) enable_tqdm = not torch.distributed.is_initialized( ) or torch.distributed.get_rank() == 0 for bin_file in tqdm( @@ -147,7 +152,7 @@ def pt_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - state = torch.load(bin_file, map_location="cpu", weights_only=True) + state = torch.load(bin_file, map_location=device, weights_only=True) yield from state.items() del state diff --git a/fastvideo/v1/pipelines/composed_pipeline_base.py b/fastvideo/v1/pipelines/composed_pipeline_base.py index 55dae4ccf..2d33ff621 100644 --- a/fastvideo/v1/pipelines/composed_pipeline_base.py +++ b/fastvideo/v1/pipelines/composed_pipeline_base.py @@ -249,7 +249,6 @@ def load_modules( module_name=module_name, component_model_path=component_model_path, transformers_or_diffusers=transformers_or_diffusers, - architecture=architecture, fastvideo_args=fastvideo_args, ) logger.info("Loaded module %s from %s", module_name, diff --git a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py index 5866c46ee..8db6f1d28 100644 --- a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py +++ b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_base.py @@ -18,7 +18,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset import ValidationDataset, getdataset from 
fastvideo.v1.dataset.preprocessing_datasets import PreprocessBatch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.composed_pipeline_base import ComposedPipelineBase @@ -328,7 +328,8 @@ def preprocess_video_and_text(self, fastvideo_args: FastVideoArgs, args): # VAE with torch.autocast("cuda", dtype=torch.float32): latents = self.get_module("vae").encode( - valid_data["pixel_values"].to(get_torch_device())).mean + valid_data["pixel_values"].to( + get_local_torch_device())).mean # Get extra features if needed extra_features = self.get_extra_features( diff --git a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py index 58ba09bef..286aa9dfd 100644 --- a/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py +++ b/fastvideo/v1/pipelines/preprocess/preprocess_pipeline_i2v.py @@ -13,7 +13,7 @@ from PIL import Image from fastvideo.v1.dataset.dataloader.schema import pyarrow_schema_i2v -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.models.vision_utils import (get_default_height_width, @@ -82,8 +82,8 @@ def get_extra_features(self, valid_data: Dict[str, Any], fastvideo_args: FastVideoArgs) -> Dict[str, Any]: # TODO(will): move these to cpu at some point - self.get_module("image_encoder").to(get_torch_device()) - self.get_module("vae").to(get_torch_device()) + self.get_module("image_encoder").to(get_local_torch_device()) + self.get_module("vae").to(get_local_torch_device()) features = {} """Get CLIP features from the first frame of each video.""" @@ -107,7 +107,7 @@ def get_extra_features(self, valid_data: Dict[str, Any], # Get CLIP features pixel_values = torch.cat( [img['pixel_values'] for img in processed_images], - dim=0).to(get_torch_device()) + dim=0).to(get_local_torch_device()) with torch.no_grad(): image_inputs = {'pixel_values': pixel_values} with set_forward_context(current_timestep=0, attn_metadata=None): @@ -129,8 +129,8 @@ def get_extra_features(self, valid_data: Dict[str, Any], height, width) ], dim=2) - video_condition = video_condition.to(device=get_torch_device(), - dtype=torch.float32) + video_condition = video_condition.to( + device=get_local_torch_device(), dtype=torch.float32) video_conditions.append(video_condition) video_conditions = torch.cat(video_conditions, dim=0) diff --git a/fastvideo/v1/pipelines/stages/decoding.py b/fastvideo/v1/pipelines/stages/decoding.py index ea75f7473..0da9e461d 100644 --- a/fastvideo/v1/pipelines/stages/decoding.py +++ b/fastvideo/v1/pipelines/stages/decoding.py @@ -5,7 +5,7 @@ import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.vaes.common import ParallelTiledVAE @@ -61,7 +61,7 @@ def forward( Returns: The batch with decoded outputs. 
""" - self.vae = self.vae.to(get_torch_device()) + self.vae = self.vae.to(get_local_torch_device()) latents = batch.latents # TODO(will): remove this once we add input/output validation for stages @@ -119,6 +119,5 @@ def forward( self.maybe_free_model_hooks() self.vae.to("cpu") - torch.cuda.empty_cache() return batch diff --git a/fastvideo/v1/pipelines/stages/denoising.py b/fastvideo/v1/pipelines/stages/denoising.py index 2070bf93e..6fc7b79c2 100644 --- a/fastvideo/v1/pipelines/stages/denoising.py +++ b/fastvideo/v1/pipelines/stages/denoising.py @@ -12,8 +12,9 @@ from fastvideo.v1.attention import get_attn_backend from fastvideo.v1.configs.pipelines.base import STA_Mode -from fastvideo.v1.distributed import (get_sp_parallel_rank, get_sp_world_size, - get_torch_device, get_world_group) +from fastvideo.v1.distributed import (get_local_torch_device, + get_sp_parallel_rank, get_sp_world_size, + get_world_group) from fastvideo.v1.distributed.communication_op import ( sequence_model_parallel_all_gather) from fastvideo.v1.fastvideo_args import FastVideoArgs @@ -192,7 +193,7 @@ def forward( [fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0], dtype=torch.float32, - device=get_torch_device(), + device=get_local_torch_device(), ).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None) diff --git a/fastvideo/v1/pipelines/stages/encoding.py b/fastvideo/v1/pipelines/stages/encoding.py index 33bd76dca..410dc2aee 100644 --- a/fastvideo/v1/pipelines/stages/encoding.py +++ b/fastvideo/v1/pipelines/stages/encoding.py @@ -7,7 +7,7 @@ import PIL.Image import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.vaes.common import ParallelTiledVAE @@ -49,7 +49,7 @@ def forward( Returns: The batch with encoded outputs. """ - self.vae = self.vae.to(get_torch_device()) + self.vae = self.vae.to(get_local_torch_device()) assert batch.height is not None assert batch.width is not None @@ -65,7 +65,8 @@ def forward( image, vae_scale_factor=self.vae.spatial_compression_ratio, height=batch.height, - width=batch.width).to(get_torch_device(), dtype=torch.float32) + width=batch.width).to(get_local_torch_device(), + dtype=torch.float32) image = image.unsqueeze(2) else: @@ -78,7 +79,7 @@ def forward( batch.num_frames - 1, batch.height, batch.width) ], dim=2) - video_condition = video_condition.to(device=get_torch_device(), + video_condition = video_condition.to(device=get_local_torch_device(), dtype=torch.float32) # Setup VAE precision diff --git a/fastvideo/v1/pipelines/stages/image_encoding.py b/fastvideo/v1/pipelines/stages/image_encoding.py index 27cd03605..1dd3f87ab 100644 --- a/fastvideo/v1/pipelines/stages/image_encoding.py +++ b/fastvideo/v1/pipelines/stages/image_encoding.py @@ -7,7 +7,7 @@ import torch -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger @@ -55,12 +55,12 @@ def forward( The batch with encoded prompt embeddings. 
""" if fastvideo_args.use_cpu_offload: - self.image_encoder = self.image_encoder.to(get_torch_device()) + self.image_encoder = self.image_encoder.to(get_local_torch_device()) image = batch.pil_image image_inputs = self.image_processor( - images=image, return_tensors="pt").to(get_torch_device()) + images=image, return_tensors="pt").to(get_local_torch_device()) with set_forward_context(current_timestep=0, attn_metadata=None): outputs = self.image_encoder(**image_inputs) image_embeds = outputs.last_hidden_state diff --git a/fastvideo/v1/pipelines/stages/latent_preparation.py b/fastvideo/v1/pipelines/stages/latent_preparation.py index 2926a53bd..2142edc4e 100644 --- a/fastvideo/v1/pipelines/stages/latent_preparation.py +++ b/fastvideo/v1/pipelines/stages/latent_preparation.py @@ -5,7 +5,7 @@ from diffusers.utils.torch_utils import randn_tensor -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -62,7 +62,7 @@ def forward( # Get required parameters dtype = batch.prompt_embeds[0].dtype - device = get_torch_device() + device = get_local_torch_device() generator = batch.generator latents = batch.latents num_frames = latent_num_frames if latent_num_frames is not None else batch.num_frames diff --git a/fastvideo/v1/pipelines/stages/text_encoding.py b/fastvideo/v1/pipelines/stages/text_encoding.py index 4bf4ef2f4..0e5e14125 100644 --- a/fastvideo/v1/pipelines/stages/text_encoding.py +++ b/fastvideo/v1/pipelines/stages/text_encoding.py @@ -5,9 +5,7 @@ This module contains implementations of prompt encoding stages for diffusion pipelines. 
""" -import torch - -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -62,8 +60,6 @@ def forward( fastvideo_args.pipeline_config.text_encoder_configs, fastvideo_args.pipeline_config.preprocess_text_funcs, fastvideo_args.pipeline_config.postprocess_text_funcs): - if fastvideo_args.use_cpu_offload: - text_encoder = text_encoder.to(get_torch_device()) assert isinstance(batch.prompt, (str, list)) if isinstance(batch.prompt, str): @@ -71,8 +67,9 @@ def forward( texts = [] for prompt_str in batch.prompt: texts.append(preprocess_func(prompt_str)) - text_inputs = tokenizer( - texts, **encoder_config.tokenizer_kwargs).to(get_torch_device()) + text_inputs = tokenizer(texts, + **encoder_config.tokenizer_kwargs).to( + get_local_torch_device()) input_ids = text_inputs["input_ids"] attention_mask = text_inputs["attention_mask"] with set_forward_context(current_timestep=0, attn_metadata=None): @@ -91,8 +88,8 @@ def forward( assert isinstance(batch.negative_prompt, str) negative_text = preprocess_func(batch.negative_prompt) negative_text_inputs = tokenizer( - negative_text, - **encoder_config.tokenizer_kwargs).to(get_torch_device()) + negative_text, **encoder_config.tokenizer_kwargs).to( + get_local_torch_device()) negative_input_ids = negative_text_inputs["input_ids"] negative_attention_mask = negative_text_inputs["attention_mask"] with set_forward_context(current_timestep=0, @@ -110,10 +107,6 @@ def forward( batch.negative_attention_mask.append( negative_attention_mask) - if fastvideo_args.use_cpu_offload: - text_encoder.to('cpu') - torch.cuda.empty_cache() - return batch def verify_input(self, batch: ForwardBatch, diff --git a/fastvideo/v1/pipelines/stages/timestep_preparation.py b/fastvideo/v1/pipelines/stages/timestep_preparation.py index d30134a47..475a0ef31 100644 --- a/fastvideo/v1/pipelines/stages/timestep_preparation.py +++ b/fastvideo/v1/pipelines/stages/timestep_preparation.py @@ -7,7 +7,7 @@ import inspect -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch @@ -45,7 +45,7 @@ def forward( The batch with prepared timesteps. """ scheduler = self.scheduler - device = get_torch_device() + device = get_local_torch_device() num_inference_steps = batch.num_inference_steps timesteps = batch.timesteps sigmas = batch.sigmas diff --git a/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py b/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py index ba3d10172..b25a02e59 100644 --- a/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py +++ b/fastvideo/v1/pipelines/stepvideo/stepvideo_pipeline.py @@ -14,7 +14,7 @@ import torch from huggingface_hub import hf_hub_download -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.encoders.bert import HunyuanClip # type: ignore @@ -78,7 +78,7 @@ def initialize_pipeline(self, fastvideo_args: FastVideoArgs): """ Initialize the pipeline. 
""" - target_device = get_torch_device() + target_device = get_local_torch_device() llm_dir = os.path.join(self.model_path, "step_llm") clip_dir = os.path.join(self.model_path, "hunyuan_clip") text_enc = self.build_llm(llm_dir, target_device) @@ -129,7 +129,6 @@ def load_modules(self, fastvideo_args: FastVideoArgs) -> Dict[str, Any]: module_name=module_name, component_model_path=component_model_path, transformers_or_diffusers=transformers_or_diffusers, - architecture=architecture, fastvideo_args=fastvideo_args, ) logger.info("Loaded module %s from %s", module_name, diff --git a/fastvideo/v1/tests/encoders/test_clip_encoder.py b/fastvideo/v1/tests/encoders/test_clip_encoder.py index 9a65e87b4..1bf1ffae5 100644 --- a/fastvideo/v1/tests/encoders/test_clip_encoder.py +++ b/fastvideo/v1/tests/encoders/test_clip_encoder.py @@ -6,7 +6,7 @@ import pytest import torch from transformers import AutoConfig - +import gc from fastvideo.models.hunyuan.text_encoder import (load_text_encoder, load_tokenizer) # from fastvideo.v1.models.hunyuan.text_encoder import load_text_encoder, load_tokenizer @@ -16,6 +16,8 @@ from fastvideo.v1.logger import init_logger from fastvideo.v1.utils import maybe_download_model from fastvideo.v1.configs.models.encoders import CLIPTextConfig +from torch.distributed.tensor import DTensor +from torch.testing import assert_close logger = init_logger(__name__) @@ -61,12 +63,11 @@ def test_clip_encoder(): from fastvideo.v1.models.loader.component_loader import TextEncoderLoader loader = TextEncoderLoader() - model2 = loader.load(TEXT_ENCODER_PATH, "", args) + model2 = loader.load(TEXT_ENCODER_PATH, args) # Load the HuggingFace implementation directly # model2 = CLIPTextModel(hf_config) # model2 = model2.to(torch.float16) - model2 = model2.to(device) model2.eval() # Sanity check weights between the two models @@ -78,19 +79,20 @@ def test_clip_encoder(): logger.info("Model1 has %d parameters", len(params1)) logger.info("Model2 has %d parameters", len(params2)) - # Compare a few key parameters - - # weight_diffs = [] - # for (name1, param1), (name2, param2) in zip( - # sorted(params1.items()), sorted(params2.items()) - # ): - # # if len(weight_diffs) < 5: # Just check a few parameters - # max_diff = torch.max(torch.abs(param1 - param2)).item() - # mean_diff = torch.mean(torch.abs(param1 - param2)).item() - # weight_diffs.append((name1, name2, max_diff, mean_diff)) - # logger.info(f"Parameter: {name1} vs {name2}") - # logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}") - + for name1, param1 in sorted(params1.items()): + name2 = name1 + skip = False + for param_name, weight_name, shard_id in model2.config.arch_config.stacked_params_mapping: + if weight_name not in name1: + skip = True + # stacked params are more troublesome + if skip: + continue + param2 = params2[name2] + param2 = param2.to_local().to(device) if isinstance(param2, DTensor) else param2.to(device) + assert_close(param1, param2, atol=1e-4, rtol=1e-4) + gc.collect() + torch.cuda.empty_cache() # Load tokenizer tokenizer, _ = load_tokenizer(tokenizer_type="clipL", tokenizer_path=args.model_path, diff --git a/fastvideo/v1/tests/encoders/test_llama_encoder.py b/fastvideo/v1/tests/encoders/test_llama_encoder.py index 9848d8588..a5b15d733 100644 --- a/fastvideo/v1/tests/encoders/test_llama_encoder.py +++ b/fastvideo/v1/tests/encoders/test_llama_encoder.py @@ -5,7 +5,7 @@ import pytest import torch from transformers import AutoConfig - +import gc from fastvideo.models.hunyuan.text_encoder import (load_text_encoder, 
load_tokenizer) from fastvideo.v1.configs.pipelines import PipelineConfig @@ -15,7 +15,8 @@ from fastvideo.v1.models.loader.component_loader import TextEncoderLoader from fastvideo.v1.utils import maybe_download_model from fastvideo.v1.configs.models.encoders import LlamaConfig - +from torch.distributed.tensor import DTensor +from torch.testing import assert_close logger = init_logger(__name__) os.environ["MASTER_ADDR"] = "localhost" @@ -58,11 +59,10 @@ def test_llama_encoder(): device=device) loader = TextEncoderLoader() device = torch.device("cuda:0") - model2 = loader.load(TEXT_ENCODER_PATH, "", args) + model2 = loader.load(TEXT_ENCODER_PATH, args) # Convert to float16 and move to device # model2 = model2.to(torch.float16) - model2 = model2.to(device) model2.eval() # Sanity check weights between the two models @@ -77,34 +77,28 @@ def test_llama_encoder(): # Compare a few key parameters weight_diffs = [] # check if embed_tokens are the same - print(model1.embed_tokens.weight.shape, model2.embed_tokens.weight.shape) + device = model1.embed_tokens.weight.device assert torch.allclose(model1.embed_tokens.weight, - model2.embed_tokens.weight) + model2.embed_tokens.weight.to_local().to(device) if isinstance(model2.embed_tokens.weight, DTensor) else model2.embed_tokens.weight.to(device)) weights = [ "layers.{}.input_layernorm.weight", "layers.{}.post_attention_layernorm.weight" ] - # for (name1, param1), (name2, param2) in zip( - # sorted(params1.items()), sorted(params2.items()) - # ): - for layer_idx in range(hf_config.num_hidden_layers): - for w in weights: - name1 = w.format(layer_idx) - name2 = w.format(layer_idx) - p1 = params1[name1] - p2 = params2[name2] - # print(type(p2)) - if "gate_up" in name2: - # print("skipping gate_up") - continue - try: - # logger.info(f"Parameter: {name1} vs {name2}") - max_diff = torch.max(torch.abs(p1 - p2)).item() - mean_diff = torch.mean(torch.abs(p1 - p2)).item() - weight_diffs.append((name1, name2, max_diff, mean_diff)) - # logger.info(f" Max diff: {max_diff}, Mean diff: {mean_diff}") - except Exception as e: - logger.info("Error comparing %s and %s: %s", name1, name2, e) + + for name1, param1 in sorted(params1.items()): + name2 = name1 + skip = False + for param_name, weight_name, shard_id in model2.config.arch_config.stacked_params_mapping: + if weight_name in name1: + skip = True + # stacked params are more troublesome + if skip: + continue + param2 = params2[name2] + param2 = param2.to_local().to(device) if isinstance(param2, DTensor) else param2.to(device) + assert_close(param1, param2, atol=1e-4, rtol=1e-4) + gc.collect() + torch.cuda.empty_cache() tokenizer, _ = load_tokenizer(tokenizer_type="llm", tokenizer_path=TOKENIZER_PATH, diff --git a/fastvideo/v1/tests/encoders/test_t5_encoder.py b/fastvideo/v1/tests/encoders/test_t5_encoder.py index 9ff3c4c8a..e56a3e024 100644 --- a/fastvideo/v1/tests/encoders/test_t5_encoder.py +++ b/fastvideo/v1/tests/encoders/test_t5_encoder.py @@ -4,6 +4,8 @@ import numpy as np import pytest import torch +from torch.distributed.tensor import DTensor +from torch.testing import assert_close from transformers import AutoConfig, AutoTokenizer, UMT5EncoderModel from fastvideo.v1.configs.pipelines import PipelineConfig @@ -41,13 +43,13 @@ def test_t5_encoder(): tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH) - args = FastVideoArgs(model_path=TEXT_ENCODER_PATH, pipeline_config=PipelineConfig(text_encoder_configs=(T5Config(),), text_encoder_precisions=(precision_str,))) + args =
FastVideoArgs(model_path=TEXT_ENCODER_PATH, + pipeline_config=PipelineConfig(text_encoder_configs=(T5Config(),), + text_encoder_precisions=(precision_str,)), + pin_cpu_memory=False) loader = TextEncoderLoader() - model2 = loader.load(TEXT_ENCODER_PATH, "", args) - - # Convert to float16 and move to device - # model2 = model2.to(precision) - model2 = model2.to(device) + model2 = loader.load(TEXT_ENCODER_PATH, args) + model2 = model2.to(precision) model2.eval() # Sanity check weights between the two models @@ -64,23 +66,17 @@ def test_t5_encoder(): weights = ["encoder.block.{}.layer.0.layer_norm.weight", "encoder.block.{}.layer.0.SelfAttention.relative_attention_bias.weight", \ "encoder.block.{}.layer.0.SelfAttention.o.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_0.weight", "encoder.block.{}.layer.1.DenseReluDense.wi_1.weight",\ "encoder.block.{}.layer.1.DenseReluDense.wo.weight", \ - "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight", "shared.weight"] + "encoder.block.{}.layer.1.layer_norm.weight", "encoder.final_layer_norm.weight"] + for idx in range(hf_config.num_hidden_layers): for w in weights: name1 = w.format(idx) name2 = w.format(idx) p1 = params1[name1] p2 = params2[name2] - assert p1.dtype == p2.dtype - try: - logger.info("Parameter: %s vs %s", name1, name2) - max_diff = torch.max(torch.abs(p1 - p2)).item() - mean_diff = torch.mean(torch.abs(p1 - p2)).item() - weight_diffs.append((name1, name2, max_diff, mean_diff)) - logger.info(" Max diff: %s, Mean diff: %s", max_diff, - mean_diff) - except Exception as e: - logger.info("Error comparing %s and %s: %s", name1, name2, e) + p2 = (p2.to_local() if isinstance(p2, DTensor) else p2).to(p1) + assert_close(p1, p2, atol=1e-4, rtol=1e-4) + # Test with some sample prompts prompts = [ @@ -134,7 +130,7 @@ def test_t5_encoder(): max_diff_hidden.item()) logger.info("Mean difference in last hidden states: %s", mean_diff_hidden.item()) - + logger.info("Max memory allocated: %s GB", torch.cuda.max_memory_allocated() / 1024**3) # Check if outputs are similar (allowing for small numerical differences) assert mean_diff_hidden < 1e-4, \ f"Hidden states differ significantly: mean diff = {mean_diff_hidden.item()}" diff --git a/fastvideo/v1/tests/transformers/test_hunyuanvideo.py b/fastvideo/v1/tests/transformers/test_hunyuanvideo.py index 24e08cdb3..73eb39d47 100644 --- a/fastvideo/v1/tests/transformers/test_hunyuanvideo.py +++ b/fastvideo/v1/tests/transformers/test_hunyuanvideo.py @@ -80,7 +80,10 @@ def test_hunyuanvideo_distributed(): # Initialize with identical weights model = initialize_identical_weights(model, seed=42) - shard_model(model, cpu_offload=False, reshard_after_forward=True) + shard_model(model, cpu_offload=True, + reshard_after_forward=True, + fsdp_shard_conditions=model._fsdp_shard_conditions + ) for n, p in chain(model.named_parameters(), model.named_buffers()): if p.is_meta: raise RuntimeError( diff --git a/fastvideo/v1/tests/transformers/test_hunyuanvideo_load.py b/fastvideo/v1/tests/transformers/test_hunyuanvideo_load.py index e3cdde2df..f36342fa3 100644 --- a/fastvideo/v1/tests/transformers/test_hunyuanvideo_load.py +++ b/fastvideo/v1/tests/transformers/test_hunyuanvideo_load.py @@ -67,7 +67,7 @@ def test_hunyuanvideo_distributed(): args.device = torch.device(f"cuda:{LOCAL_RANK}") loader = TransformerLoader() - model = loader.load(TRANSFORMER_PATH, "", args) + model = loader.load(TRANSFORMER_PATH, args) model.eval() diff --git a/fastvideo/v1/tests/transformers/test_wanvideo.py 
b/fastvideo/v1/tests/transformers/test_wanvideo.py index 422855c75..403fd6e61 100644 --- a/fastvideo/v1/tests/transformers/test_wanvideo.py +++ b/fastvideo/v1/tests/transformers/test_wanvideo.py @@ -39,7 +39,7 @@ def test_wan_transformer(): args.device = device loader = TransformerLoader() - model2 = loader.load(TRANSFORMER_PATH, "", args).to(dtype=precision) + model2 = loader.load(TRANSFORMER_PATH, args).to(dtype=precision) model1 = WanTransformer3DModel.from_pretrained( TRANSFORMER_PATH, device=device, diff --git a/fastvideo/v1/tests/vaes/test_hunyuan_vae.py b/fastvideo/v1/tests/vaes/test_hunyuan_vae.py index 038135ce7..2f6df0368 100644 --- a/fastvideo/v1/tests/vaes/test_hunyuan_vae.py +++ b/fastvideo/v1/tests/vaes/test_hunyuan_vae.py @@ -45,7 +45,7 @@ def test_hunyuan_vae(): args.device = device loader = VAELoader() - model = loader.load(VAE_PATH, "", args) + model = loader.load(VAE_PATH, args) model.enable_tiling(tile_sample_min_height=32, tile_sample_min_width=32, diff --git a/fastvideo/v1/tests/vaes/test_wan_vae.py b/fastvideo/v1/tests/vaes/test_wan_vae.py index 35cb995fd..813cc5ab0 100644 --- a/fastvideo/v1/tests/vaes/test_wan_vae.py +++ b/fastvideo/v1/tests/vaes/test_wan_vae.py @@ -34,7 +34,7 @@ def test_wan_vae(): args.device = device loader = VAELoader() - model2 = loader.load(VAE_PATH, "", args) + model2 = loader.load(VAE_PATH, args) assert model2.use_feature_cache # Default to use the original WanVAE algorithm model1 = AutoencoderKLWan.from_pretrained( diff --git a/fastvideo/v1/training/training_pipeline.py b/fastvideo/v1/training/training_pipeline.py index 60bd34c7f..0596f33f4 100644 --- a/fastvideo/v1/training/training_pipeline.py +++ b/fastvideo/v1/training/training_pipeline.py @@ -24,8 +24,9 @@ from fastvideo.v1.dataset import build_parquet_map_style_dataloader from fastvideo.v1.dataset.dataloader.schema import ( pyarrow_schema_t2v, pyarrow_schema_t2v_validation) -from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, get_sp_group, - get_torch_device, get_world_group) +from fastvideo.v1.distributed import (cleanup_dist_env_and_memory, + get_local_torch_device, get_sp_group, + get_world_group) from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.forward_context import set_forward_context from fastvideo.v1.logger import init_logger @@ -67,8 +68,8 @@ def set_schemas(self) -> None: def initialize_training_pipeline(self, training_args: TrainingArgs): logger.info("Initializing training pipeline...") + self.device = get_local_torch_device() self.training_args = training_args - self.device = get_torch_device() world_group = get_world_group() self.world_size = world_group.world_size self.global_rank = world_group.rank @@ -183,12 +184,12 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: encoder_attention_mask = batch['text_attention_mask'] infos = batch['info_list'] - training_batch.latents = latents.to(get_torch_device(), + training_batch.latents = latents.to(get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_hidden_states = encoder_hidden_states.to( - get_torch_device(), dtype=torch.bfloat16) + get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_attention_mask = encoder_attention_mask.to( - get_torch_device(), dtype=torch.bfloat16) + get_local_torch_device(), dtype=torch.bfloat16) training_batch.infos = infos return training_batch @@ -280,7 +281,7 @@ def _build_input_kwargs(self, "encoder_hidden_states": training_batch.encoder_hidden_states, "timestep": - 
training_batch.timesteps.to(get_torch_device(), + training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16), "encoder_attention_mask": training_batch.encoder_attention_mask, @@ -560,8 +561,9 @@ def _prepare_validation_inputs( prompt_embeds = validation_batch['text_embedding'] prompt_attention_mask = validation_batch['text_attention_mask'] - prompt_embeds = prompt_embeds.to(get_torch_device()) - prompt_attention_mask = prompt_attention_mask.to(get_torch_device()) + prompt_embeds = prompt_embeds.to(get_local_torch_device()) + prompt_attention_mask = prompt_attention_mask.to( + get_local_torch_device()) # Calculate sizes latents_size = [(sampling_param.num_frames - 1) // 4 + 1, diff --git a/fastvideo/v1/training/wan_i2v_training_pipeline.py b/fastvideo/v1/training/wan_i2v_training_pipeline.py index b58cfa26a..1c5475e75 100644 --- a/fastvideo/v1/training/wan_i2v_training_pipeline.py +++ b/fastvideo/v1/training/wan_i2v_training_pipeline.py @@ -8,7 +8,7 @@ from fastvideo.v1.configs.sample import SamplingParam from fastvideo.v1.dataset.dataloader.schema import ( pyarrow_schema_i2v, pyarrow_schema_i2v_validation) -from fastvideo.v1.distributed import get_torch_device +from fastvideo.v1.distributed import get_local_torch_device from fastvideo.v1.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.v1.logger import init_logger from fastvideo.v1.models.schedulers.scheduling_flow_unipc_multistep import ( @@ -85,15 +85,17 @@ def _get_next_batch(self, training_batch: TrainingBatch) -> TrainingBatch: pil_image = batch['pil_image'] infos = batch['info_list'] - training_batch.latents = latents.to(get_torch_device(), + training_batch.latents = latents.to(get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_hidden_states = encoder_hidden_states.to( - get_torch_device(), dtype=torch.bfloat16) + get_local_torch_device(), dtype=torch.bfloat16) training_batch.encoder_attention_mask = encoder_attention_mask.to( - get_torch_device(), dtype=torch.bfloat16) - training_batch.preprocessed_image = pil_image.to(get_torch_device()) - training_batch.image_embeds = clip_features.to(get_torch_device()) - training_batch.image_latents = image_latents.to(get_torch_device()) + get_local_torch_device(), dtype=torch.bfloat16) + training_batch.preprocessed_image = pil_image.to( + get_local_torch_device()) + training_batch.image_embeds = clip_features.to(get_local_torch_device()) + training_batch.image_latents = image_latents.to( + get_local_torch_device()) training_batch.infos = infos return training_batch @@ -112,8 +114,8 @@ def _prepare_dit_inputs(self, training_batch = super()._prepare_dit_inputs(training_batch) assert isinstance(training_batch.image_latents, torch.Tensor) - image_latents = training_batch.image_latents.to(get_torch_device(), - dtype=torch.bfloat16) + image_latents = training_batch.image_latents.to( + get_local_torch_device(), dtype=torch.bfloat16) training_batch.noisy_model_input = torch.cat( [training_batch.noisy_model_input, image_latents], dim=1) @@ -132,7 +134,8 @@ def _build_input_kwargs(self, # Image Embeds for conditioning image_embeds = training_batch.image_embeds assert torch.isnan(image_embeds).sum() == 0 - image_embeds = image_embeds.to(get_torch_device(), dtype=torch.bfloat16) + image_embeds = image_embeds.to(get_local_torch_device(), + dtype=torch.bfloat16) encoder_hidden_states_image = image_embeds # NOTE: noisy_model_input already contains concatenated image_latents from _prepare_dit_inputs @@ -142,7 +145,7 @@ def _build_input_kwargs(self, 
"encoder_hidden_states": training_batch.encoder_hidden_states, "timestep": - training_batch.timesteps.to(get_torch_device(), + training_batch.timesteps.to(get_local_torch_device(), dtype=torch.bfloat16), "encoder_attention_mask": training_batch.encoder_attention_mask, @@ -166,9 +169,9 @@ def _prepare_validation_inputs( infos = validation_batch['info_list'] prompt = infos[0]['prompt'] - prompt_embeds = embeddings.to(get_torch_device()) - prompt_attention_mask = masks.to(get_torch_device()) - clip_features = clip_features.to(get_torch_device()) + prompt_embeds = embeddings.to(get_local_torch_device()) + prompt_attention_mask = masks.to(get_local_torch_device()) + clip_features = clip_features.to(get_local_torch_device()) # Calculate sizes latents_size = [(sampling_param.num_frames - 1) // 4 + 1,