🐛 Describe the bug
The dequantize_per_channel operator produces incorrect output when using the channels last dim order. In my example case, the input has shape [2,2,3,3].
The per-channel scales of the input are [ 0.0016989057185128331, 0.001776964869350195 ]
The input data is [ -20,-59,-22,-40,127,-108,-57,117,24,-103,48,-110,80,15,-10,-75,-77,-46,-12,-66,35,-87,-50,-80,12,-127,107,91,115,-54,-6,-6,-41,46,42,-83 ]
The output is [ -0.033978,0.000000,-0.037376,0.000000,0.215761,0.000000,-0.096838,0.000000,0.040774,0.000000,0.081547,0.000000,0.135912,0.000000,-0.016989,0.000000,-0.130816,0.000000,-0.021324,0.000000,0.062194,0.000000,-0.088848,0.000000,0.021324,0.000000,0.190135,0.000000,0.204351,0.000000,-0.010662,0.000000,-0.072856,0.000000,0.074633,0.000000 ]
Every other output is 0, which is clearly incorrect.
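For reference, the incorrect output can be reproduced numerically: its nonzero entries are exactly every second input element dequantized (the first nine with scale[0], the next nine with scale[1]), interleaved with zeros. A numpy sketch of this observation, worked out by hand from the numbers above:

```python
import numpy as np

# Per-channel scales and raw int8 input values from the report above.
scales = np.array([0.0016989057185128331, 0.001776964869350195])
data = np.array([-20, -59, -22, -40, 127, -108, -57, 117, 24,
                 -103, 48, -110, 80, 15, -10, -75, -77, -46,
                 -12, -66, 35, -87, -50, -80, 12, -127, 107,
                 91, 115, -54, -6, -6, -41, 46, 42, -83], dtype=np.float64)

# Take every second input element; dequantize the first 9 of those with
# scales[0] and the next 9 with scales[1]; write them to the even output
# slots. The odd slots stay zero -- this matches the buggy output exactly.
buggy = np.zeros(36)
buggy[0::2] = data[0::2] * np.repeat(scales, 9)

reported = [-0.033978, 0.0, -0.037376, 0.0, 0.215761, 0.0]
assert np.allclose(buggy[:6], reported, atol=1e-6)
```

This suggests the kernel is reading the input with channels-last strides while writing (and skipping) output elements as if the tensor were contiguous.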
The model is set to use the channels last memory format (dim order), so it does not take the dequantize_per_channel_optimized() fast path in kernels/quantized/cpu/op_dequantize.cpp (line 403).
The expected output was [ -0.03397811, -0.10023544, -0.03737593, -0.06795623, 0.21576103, -0.18348182, -0.09683763, 0.19877197, 0.04077374, -0.18302738, 0.08529431, -0.19546614, 0.14215719, 0.02665447, -0.01776965, -0.13327237, -0.13682629, -0.08174038, -0.02038687, -0.11212778, 0.0594617 , -0.1478048 , -0.08494529, -0.13591246, 0.02038687, -0.21576103, 0.18178291, 0.1617038 , 0.20435096, -0.0959561 , -0.01066179, -0.01066179, -0.07285556, 0.08174038, 0.07463252, -0.14748808 ]
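The expected values follow standard per-channel dequantization, out = (input - zero_point) * scale[channel]. A numpy sketch that reproduces them, assuming zero points of 0 (which matches the numbers above) and treating the data list as N=2, C=2, H=W=3 in logical NCHW order with the scales applied along the channel axis:

```python
import numpy as np

# Per-channel scales and raw int8 input values from the report above.
scales = np.array([0.0016989057185128331, 0.001776964869350195])
data = np.array([-20, -59, -22, -40, 127, -108, -57, 117, 24,
                 -103, 48, -110, 80, 15, -10, -75, -77, -46,
                 -12, -66, 35, -87, -50, -80, 12, -127, 107,
                 91, 115, -54, -6, -6, -41, 46, 42, -83],
                dtype=np.float64).reshape(2, 2, 3, 3)

# Broadcast the scales over axis 1 (the channel axis); zero points are
# assumed to be 0 since input * scale reproduces the expected output.
expected = data * scales.reshape(1, 2, 1, 1)

reported = [-0.03397811, -0.10023544, -0.03737593, -0.06795623]
assert np.allclose(expected.ravel()[:4], reported, atol=1e-7)
```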
The example model is attached; the operator described above is the one that dequantizes the weights.
Versions
Collecting environment information...
PyTorch version: 2.10.0.dev20251025+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-2ubuntu1~20.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.10
Libc version: glibc-2.31
Python version: 3.10.12 (main, Oct 1 2025, 11:13:47) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 186
Model name: 13th Gen Intel(R) Core(TM) i5-1350P
Stepping: 2
CPU MHz: 2188.810
BogoMIPS: 4377.62
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 384 KiB
L1i cache: 256 KiB
L2 cache: 10 MiB
L3 cache: 12 MiB
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Mitigation; Clear Register File
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves avx_vnni vnmi umip waitpkg gfni vaes vpclmulqdq rdpid movdiri movdir64b fsrm md_clear serialize flush_l1d arch_capabilities
Versions of relevant libraries:
[pip3] executorch==1.1.0a0+a069bba
[pip3] flake8==6.1.0
[pip3] flake8-breakpoint==1.1.0
[pip3] flake8-bugbear==24.4.26
[pip3] flake8-comprehensions==3.14.0
[pip3] flake8-plugin-utils==1.3.3
[pip3] flake8-pyi==23.5.0
[pip3] mypy==1.14.1
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.0.0
[pip3] optree==0.18.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.10.0.dev20251025+cpu
[pip3] torchao==0.14.0+git01849b2b1
[pip3] torchaudio==2.10.0.dev20251025+cpu
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.6.1
[pip3] torchvision==0.25.0.dev20251025+cpu
[conda] Could not collect