 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
 from Automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, WanPipeline
 from nemo_automodel.components.distributed import parallelizer
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.shared.utils import dtype_from_str
@@ -155,30 +156,48 @@ def from_pretrained(
             setattr(pipe, comp_name, parallel_module)
         return pipe, created_managers

+
+class NeMoWanPipeline(WanPipeline):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        return NeMoAutoDiffusionPipeline.from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_config(
         cls,
-        pretrained_model_name_or_path: str,
-        *model_args,
+        model_id,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        config: dict = None,
         parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None,
         device: Optional[torch.device] = None,
-        torch_dtype: Any = "auto",
         move_to_device: bool = True,
-        load_for_training: bool = False,
         components_to_load: Optional[Iterable[str]] = None,
-        **kwargs,
-    ) -> tuple[DiffusionPipeline, Dict[str, FSDP2Manager]]:
-        config = WanTransformer3DModel.from_pretrained(
-            pretrained_model_name_or_path,
-            subfolder="transformer",
+    ):
+        # Load the pretrained transformer only to obtain its config
+        from diffusers import WanTransformer3DModel
+
+        if model_id is not None:
+            transformer = WanTransformer3DModel.from_pretrained(
+                model_id,
+                subfolder="transformer",
+                torch_dtype=torch.bfloat16,
+            )
+
+            # Get config and reinitialize with random weights
+            config = copy.deepcopy(transformer.config)
+            del transformer
+
+        # Initialize with random weights
+        transformer = WanTransformer3DModel.from_config(config)
+
+        # Load pipeline with random transformer
+        pipe = WanPipeline.from_pretrained(
+            model_id,
+            transformer=transformer,
             torch_dtype=torch_dtype,
-            **kwargs,
-        )
-        pipe: DiffusionPipeline = DiffusionPipeline.from_config(
-            config,
-            *model_args,
-            torch_dtype=torch_dtype,
-            **kwargs,
         )
         # Decide device
         dev = _choose_device(device)
@@ -190,13 +209,6 @@ def from_config(
                 logger.info("[INFO] Moving module: %s to device/dtype", name)
                 _move_module_to_device(module, dev, torch_dtype)

-        # If loading for training, ensure the target module parameters are trainable
-        if load_for_training:
-            for name, module in _iter_pipeline_modules(pipe):
-                if not components_to_load or name in components_to_load:
-                    logger.info("[INFO] Ensuring params trainable: %s", name)
-                    _ensure_params_trainable(module, module_name=name)
-
         # Use per-component FSDP2Manager init-args to parallelize components
         created_managers: Dict[str, FSDP2Manager] = {}
         if parallel_scheme is not None:
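For context, a minimal usage sketch of the NeMoWanPipeline.from_config entry point this commit introduces. It is not part of the change: the import path, the example checkpoint id, and the returned (pipe, managers) tuple are assumptions inferred from the surrounding code (from_pretrained returns pipe, created_managers in the unchanged part of the file).

import torch

# Assumed import path; the actual module location is not shown in this diff.
from nemo_automodel.components.diffusion.pipeline import NeMoWanPipeline

# Build a Wan pipeline whose transformer is re-initialized with random weights
# (e.g. for training from scratch) while the other components load as usual.
pipe, managers = NeMoWanPipeline.from_config(
    model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # example checkpoint with a `transformer` subfolder
    torch_dtype=torch.bfloat16,
    # parallel_scheme maps component name -> FSDP2Manager init kwargs; omitted here
    # because the exact manager arguments are not visible in this diff.
    parallel_scheme=None,
    components_to_load=["transformer"],
)

Rebuilding the transformer from its config keeps the module layout identical to the pretrained checkpoint while leaving its weights randomly initialized for training.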