Commit 8ad065b

introduce NeMoWanPipeline
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 320f92b commit 8ad065b

File tree

1 file changed: +36 -24 lines changed


dfm/src/Automodel/_diffusers/auto_diffusion_pipeline.py

Lines changed: 36 additions & 24 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 import logging
 import os
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
 import torch.nn as nn
 from Automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
-from diffusers import DiffusionPipeline
+from diffusers import DiffusionPipeline, WanPipeline
 from nemo_automodel.components.distributed import parallelizer
 from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
 from nemo_automodel.shared.utils import dtype_from_str
@@ -155,30 +156,48 @@ def from_pretrained(
             setattr(pipe, comp_name, parallel_module)
         return pipe, created_managers

+
+class NeMoWanPipeline(WanPipeline):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        return NeMoAutoDiffusionPipeline.from_pretrained(*args, **kwargs)
+
     @classmethod
     def from_config(
         cls,
-        pretrained_model_name_or_path: str,
-        *model_args,
+        model_id,
+        torch_dtype: torch.dtype = torch.bfloat16,
+        config: dict = None,
         parallel_scheme: Optional[Dict[str, Dict[str, Any]]] = None,
         device: Optional[torch.device] = None,
-        torch_dtype: Any = "auto",
         move_to_device: bool = True,
-        load_for_training: bool = False,
         components_to_load: Optional[Iterable[str]] = None,
-        **kwargs,
-    ) -> tuple[DiffusionPipeline, Dict[str, FSDP2Manager]]:
-        config = WanTransformer3DModel.from_pretrained(
-            pretrained_model_name_or_path,
-            subfolder="transformer",
+    ):
+        # Load just the config
+        from diffusers import WanTransformer3DModel
+
+        if model_id is not None:
+            transformer = WanTransformer3DModel.from_pretrained(
+                model_id,
+                subfolder="transformer",
+                torch_dtype=torch.bfloat16,
+            )
+
+            # Get config and reinitialize with random weights
+            config = copy.deepcopy(transformer.config)
+            del transformer
+
+        # Initialize with random weights
+        transformer = WanTransformer3DModel.from_config(config)
+
+        # Load pipeline with random transformer
+        pipe = WanPipeline.from_pretrained(
+            model_id,
+            transformer=transformer,
             torch_dtype=torch_dtype,
-            **kwargs,
-        )
-        pipe: DiffusionPipeline = DiffusionPipeline.from_config(
-            config,
-            *model_args,
-            torch_dtype=torch_dtype,
-            **kwargs,
         )
         # Decide device
         dev = _choose_device(device)
@@ -190,13 +209,6 @@ def from_config(
             logger.info("[INFO] Moving module: %s to device/dtype", name)
             _move_module_to_device(module, dev, torch_dtype)

-        # If loading for training, ensure the target module parameters are trainable
-        if load_for_training:
-            for name, module in _iter_pipeline_modules(pipe):
-                if not components_to_load or name in components_to_load:
-                    logger.info("[INFO] Ensuring params trainable: %s", name)
-                    _ensure_params_trainable(module, module_name=name)
-
         # Use per-component FSDP2Manager init-args to parallelize components
        created_managers: Dict[str, FSDP2Manager] = {}
        if parallel_scheme is not None:
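
For reference, a minimal usage sketch of the new NeMoWanPipeline.from_config entry point introduced by this commit. The import path, the model id, and the shape of the return value below are assumptions for illustration (the method's return statement falls outside this hunk); only the class and method names come from the diff above.

import torch

# Assumed import path, derived from the file location in this commit.
from Automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline

# Per the diff, from_config downloads the pretrained transformer only to grab
# its config, discards the weights, re-initializes the transformer randomly
# from that config, and then assembles the rest of the WanPipeline (VAE, text
# encoder, scheduler) from the same checkpoint around the random transformer,
# so the transformer can be trained from scratch.
result = NeMoWanPipeline.from_config(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative model id, not from the commit
    torch_dtype=torch.bfloat16,
)

# NeMoWanPipeline.from_pretrained, by contrast, simply delegates to
# NeMoAutoDiffusionPipeline.from_pretrained and loads the released weights.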
