|
23 | 23 | import torch |
24 | 24 |
|
25 | 25 | from ..utils import get_logger, is_accelerate_available |
| 26 | +from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS |
26 | 27 | from .hooks import HookRegistry, ModelHook |
27 | 28 |
|
28 | 29 |
|
|
39 | 40 | _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" |
40 | 41 | _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" |
41 | 42 | _GROUP_ID_LAZY_LEAF = "lazy_leafs" |
42 | | -_SUPPORTED_PYTORCH_LAYERS = ( |
43 | | - torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, |
44 | | - torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, |
45 | | - torch.nn.Linear, |
46 | | - # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX |
47 | | - # because of double invocation of the same norm layer in CogVideoXLayerNorm |
48 | | -) |
49 | 43 | # fmt: on |
50 | 44 |
|
51 | 45 |
|
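The tuple deleted in this hunk is replaced by the shared `_GO_LC_SUPPORTED_PYTORCH_LAYERS` constant imported from `._common` at the top of the file (the prefix presumably stands for group offloading / layerwise casting). A minimal sketch of what that shared constant might look like, assuming it keeps the same layer types as the removed block; the real `_common.py` may list additional entries:

```python
# _common.py (sketch): shared tuple of leaf layer types supported by group
# offloading and layerwise casting. Assumes the same contents as the tuple
# removed from this file; the actual module may differ.
import torch

_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)
```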
@@ -367,7 +361,8 @@ def __init__(self): |
367 | 361 | def initialize_hook(self, module): |
368 | 362 | def make_execution_order_update_callback(current_name, current_submodule): |
369 | 363 | def callback(): |
370 | | - logger.debug(f"Adding {current_name} to the execution order") |
| 364 | + if not torch.compiler.is_compiling(): |
| 365 | + logger.debug(f"Adding {current_name} to the execution order") |
371 | 366 | self.execution_order.append((current_name, current_submodule)) |
372 | 367 |
|
373 | 368 | return callback |
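The guard added above skips the Python-side `logger.debug` call while `torch.compile` is tracing, so the log statement does not force a graph break. A self-contained sketch of the same pattern; the function and logger names here are illustrative, not from the source:

```python
# Sketch of the guard pattern from the diff: only log when not being traced by
# torch.compile, so the logging call stays out of the compiled graph.
import logging

import torch

logger = logging.getLogger("offload_demo")


def demo_forward(x: torch.Tensor) -> torch.Tensor:
    if not torch.compiler.is_compiling():
        # Skipped during Dynamo tracing; runs normally in eager mode.
        logger.debug("running eager forward on shape %s", tuple(x.shape))
    return x * 2


compiled_forward = torch.compile(demo_forward)
print(compiled_forward(torch.ones(2)))  # tensor([2., 2.])
```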
@@ -404,12 +399,13 @@ def post_forward(self, module, output): |
404 | 399 | # if the missing layers end up being executed in the future. |
405 | 400 | if execution_order_module_names != self._layer_execution_tracker_module_names: |
406 | 401 | unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names) |
407 | | - logger.warning( |
408 | | - "It seems like some layers were not executed during the forward pass. This may lead to problems when " |
409 | | - "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please " |
410 | | - "make sure that all layers are executed during the forward pass. The following layers were not executed:\n" |
411 | | - f"{unexecuted_layers=}" |
412 | | - ) |
| 402 | + if not torch.compiler.is_compiling(): |
| 403 | + logger.warning( |
| 404 | + "It seems like some layers were not executed during the forward pass. This may lead to problems when " |
| 405 | + "applying lazy prefetching with automatic tracing and cause device-mismatch related errors. Please " |
| 406 | + "make sure that all layers are executed during the forward pass. The following layers were not executed:\n" |
| 407 | + f"{unexecuted_layers=}" |
| 408 | + ) |
413 | 409 |
|
414 | 410 | # Remove the layer execution tracker hooks from the submodules |
415 | 411 | base_module_registry = module._diffusers_hook |
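For reference, the check above is a plain set difference between the module names registered for tracking and the names actually recorded during the traced forward pass; anything left over was never executed. A tiny illustration with made-up module names:

```python
# Toy illustration of the unexecuted-layer check; the module paths are made up.
expected = {"blocks.0.attn", "blocks.0.ff", "blocks.1.attn"}
executed = {"blocks.0.attn", "blocks.0.ff"}

if expected != executed:
    unexecuted_layers = sorted(expected - executed)
    print(f"{unexecuted_layers=}")  # unexecuted_layers=['blocks.1.attn']
```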
@@ -437,7 +433,8 @@ def post_forward(self, module, output): |
437 | 433 | for i in range(num_executed - 1): |
438 | 434 | name1, _ = self.execution_order[i] |
439 | 435 | name2, _ = self.execution_order[i + 1] |
440 | | - logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}") |
| 436 | + if not torch.compiler.is_compiling(): |
| 437 | + logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}") |
441 | 438 | group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group |
442 | 439 | group_offloading_hooks[i].next_group.onload_self = False |
443 | 440 |
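The loop in this hunk wires each group-offloading hook to the group that ran next in the traced execution order, so the following group's weights can be onloaded ahead of time, and marks that group as not responsible for onloading itself. A toy sketch of the chaining; `ToyGroup` and `ToyHook` are illustrative stand-ins, not the library's `ModuleGroup` or hook classes:

```python
# Toy sketch of lazy-prefetch chaining: each hook points at the next group in
# execution order, and that group no longer onloads itself.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyGroup:
    name: str
    onload_self: bool = True


@dataclass
class ToyHook:
    group: ToyGroup
    next_group: Optional[ToyGroup] = None


hooks = [ToyHook(ToyGroup(n)) for n in ("blocks.0", "blocks.1", "blocks.2")]
for i in range(len(hooks) - 1):
    hooks[i].next_group = hooks[i + 1].group
    # The next group will be onloaded by the previous hook, not by itself.
    hooks[i].next_group.onload_self = False

print([(h.group.name, h.next_group.name if h.next_group else None) for h in hooks])
```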
|
@@ -680,7 +677,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff |
680 | 677 | # Create module groups for leaf modules and apply group offloading hooks |
681 | 678 | modules_with_group_offloading = set() |
682 | 679 | for name, submodule in module.named_modules(): |
683 | | - if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): |
| 680 | + if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): |
684 | 681 | continue |
685 | 682 | group = ModuleGroup( |
686 | 683 | modules=[submodule], |
|
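At the leaf level, the selection logic shown above is an `isinstance` check of every named submodule against the shared tuple. A small sketch of that filtering step, assuming the constant is importable from `diffusers.hooks._common`; the helper name and the fallback tuple below are illustrative:

```python
# Sketch of leaf-module selection: walk the module tree and keep only the layer
# types listed in the shared constant.
import torch

try:
    from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
except ImportError:
    # Fallback for illustration only; the real tuple covers more layer types.
    _GO_LC_SUPPORTED_PYTORCH_LAYERS = (torch.nn.Linear, torch.nn.Conv2d)


def collect_offloadable_leaves(module: torch.nn.Module):
    # Hypothetical helper: return (name, submodule) pairs eligible for leaf-level
    # group offloading.
    return [
        (name, submodule)
        for name, submodule in module.named_modules()
        if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
    ]


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Conv2d(3, 3, 1))
print([name for name, _ in collect_offloadable_leaves(model)])  # ['0', '2']
```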