forked from pytorch/pytorch
Description
🐛 Describe the bug
A commit to release/2.7 from about two weeks ago likely breaks torch.compile for some diffusers models. Dynamo cannot handle FrozenDict, the OrderedDict subclass that diffusers uses to store and access model-specific configurations, and the forward method fails with the following error:
```
[rank3]:     source_value = DictSubclassGetItemSource(base, source_key)
[rank3]:                                              ^^^^
[rank3]: torch._dynamo.exc.InternalTorchDynamoError: NameError: name 'base' is not defined
```
Below is a simple reproducer:
"""
Bug: NameError: name 'base' is not defined
Location: torch/_dynamo/variables/builder.py line 1278
"""
import torch
from collections import OrderedDict
# FrozenDict from diffusers
class FrozenDict(OrderedDict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self.items():
setattr(self, key, value)
self.__frozen = True
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __setattr__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setattr__(name, value)
def __setitem__(self, name, value):
if hasattr(self, "__frozen") and self.__frozen:
raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
super().__setitem__(name, value)
class SimpleModel(torch.nn.Module):
"""Simple model with FrozenDict config"""
def __init__(self):
super().__init__()
self.config = FrozenDict([('patch_size', 2), ('hidden_dim', 128)])
self.linear = torch.nn.Linear(10, 10)
def forward(self, x):
# Access FrozenDict in forward - triggers bug when compiled
p = self.config['patch_size'] # BUG HERE
x = self.linear(x)
return x * p
def simple_function(x):
"""Simple function that accesses FrozenDict"""
config = FrozenDict([('patch_size', 2), ('value', 42)])
p = config['patch_size']
return x * p
if __name__ == "__main__":
print("Testing TorchDynamo FrozenDict bug.")
print(f"PyTorch version: {torch.__version__}\n")
# Test 1: Function - Eager mode
print("Test 1: Simple function - Eager mode (no compile)")
try:
result = simple_function(torch.tensor(5.0))
print(f"Result: {result}\n")
except Exception as e:
print(f"Error: {e}\n")
# Test 2: Function - Compiled
print("Test 2: Simple function - Compiled with torch.compile")
compiled_fn = torch.compile(simple_function)
try:
result = compiled_fn(torch.tensor(5.0))
print(f"Result: {result}\n")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}\n")
# Test 3: Model - Eager mode
print("Test 3: nn.Module - Eager mode (no compile)")
model = SimpleModel()
x = torch.randn(2, 10)
try:
output = model(x)
print(f"Output shape: {output.shape}\n")
except Exception as e:
print(f"Error: {e}\n")
# Test 4: Model - Compiled
print("Test 4: nn.Module - Compiled with torch.compile")
model = SimpleModel()
compiled_model = torch.compile(model)
x = torch.randn(2, 10)
try:
output = compiled_model(x)
print(f"Output shape: {output.shape}\n")
except Exception as e:
print(f"Error: {type(e).__name__}: {e}\n")
This reproducer works with https://github.com/ROCm/pytorch/tree/99ccf24, but yields the error with https://github.com/ROCm/pytorch/tree/a033df6.
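The traceback points into the branch of `torch/_dynamo/variables/builder.py` that builds a `DictSubclassGetItemSource`, which, judging by its name, handles subscripting dict subclasses. FrozenDict subclasses OrderedDict, so `self.config['patch_size']` takes that branch instead of the plain-`dict` one. A possible workaround, assuming the failure is confined to that branch and not verified against the affected build, is to hand compiled code built-in `dict` objects instead. The helper `to_plain_dict` below is hypothetical (it is not part of diffusers or PyTorch), and the `FrozenDict` here is a trimmed stand-in for the diffusers class:

```python
from collections import OrderedDict
from collections.abc import Mapping


# Trimmed stand-in for diffusers' FrozenDict: an OrderedDict subclass that
# mirrors its keys as attributes (the freezing logic is omitted here).
class FrozenDict(OrderedDict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for key, value in self.items():
            setattr(self, key, value)


def to_plain_dict(config):
    """Hypothetical helper: recursively copy any Mapping (including dict
    subclasses such as FrozenDict) into built-in dicts."""
    return {
        key: to_plain_dict(value) if isinstance(value, Mapping) else value
        for key, value in config.items()
    }


config = FrozenDict([("patch_size", 2), ("hidden_dim", 128)])
# A dict subclass rather than a built-in dict; per the traceback, this is the
# kind of object that reaches the failing Dynamo branch.
print(type(config) is dict, isinstance(config, dict))  # False True

plain = to_plain_dict(config)
print(type(plain) is dict, plain)  # True {'patch_size': 2, 'hidden_dim': 128}
```

In the reproducer, the equivalent change would be to store `to_plain_dict(self.config)` (or simply `self.config['patch_size']`) on the module in `SimpleModel.__init__` and read that inside `forward`. This only helps code you control; diffusers models that index `self.config` internally still depend on the regression being fixed.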
Versions
Failing environment (error with torch.compile):
Collecting environment information...
PyTorch version: 2.7.1+rocm7.9.0rc20250930
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.1.25392-0bdf9d75da
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-156-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.1.25392
MIOpen runtime version: 3.5.1
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9534 64-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU max MHz: 3719.5830
CPU min MHz: 1500.0000
BogoMIPS: 4892.50
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 128 MiB (128 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] onnx==1.19.0
[pip3] onnx2torch==1.5.15
[pip3] pytorch-triton-rocm==3.3.1+rocm7.9.0rc20250930
[pip3] torch==2.7.1+rocm7.9.0rc20250930
[pip3] torchaudio==2.7.1a0+rocm7.9.0rc20250930
[pip3] torchvision==0.22.1+rocm7.9.0rc20250930
[conda] Could not collect
Functioning environment:
Collecting environment information...
PyTorch version: 2.7.1+rocm7.9.0rc20250925
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.1.25384-0167af32bd
OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04.2) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35
Python version: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-156-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Instinct MI300X (gfx942:sramecc+:xnack-)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.1.25384
MIOpen runtime version: 3.5.0
Is XNNPACK available: True
CPU: (identical to the failing environment above)
Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] onnx==1.19.0
[pip3] onnx2torch==1.5.15
[pip3] pytorch-triton-rocm==3.3.1+rocm7.9.0rc20250925
[pip3] torch==2.7.1+rocm7.9.0rc20250925
[pip3] torchaudio==2.7.1a0+rocm7.9.0rc20250925
[pip3] torchvision==0.22.1+rocm7.9.0rc20250925
[conda] Could not collect