|
1 | 1 | import subprocess |
2 | 2 | from typing import Optional |
3 | 3 | import warnings |
| 4 | +import os |
4 | 5 |
|
5 | 6 | import torch |
6 | 7 |
|
@@ -55,7 +56,7 @@ def _ipex_xpu_version_prereq(major, minor): |
55 | 56 |
|
56 | 57 | def _maybe_torch_compile(func): |
57 | 58 | # torch.compile requires g++ and pytorch >= 2.0 |
58 | | - if gxx_available and _torch_version_prereq(2, 0): |
| 59 | +    if gxx_available and _torch_version_prereq(2, 0) and os.getenv('PT_HPU_LAZY_MODE', '1') == '0': |
59 | 60 | options = {} |
60 | 61 | # fx_graph_cache requires pytorch >= 2.2 |
61 | 62 | if _torch_version_prereq(2, 2): |
@@ -277,7 +278,7 @@ def mm_dequant_impl( |
277 | 278 | } |
278 | 279 |
|
279 | 280 |
|
280 | | -# @_maybe_torch_compile |
| 281 | +@_maybe_torch_compile |
281 | 282 | def quantize_4bit_impl( |
282 | 283 | A: Tensor, |
283 | 284 | absmax: Tensor = None, |
@@ -342,7 +343,7 @@ def quantize_4bit_impl( |
342 | 343 | scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) |
343 | 344 | scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) |
344 | 345 | # map [-1, 1] to nf4/fp4 |
345 | | - out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=scaled_A.device) |
| 346 | + out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) |
346 | 347 | if quant_type == "nf4": |
347 | 348 | for i in range(len(NF4_QUANT_TABLE)): |
348 | 349 | out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i |
@@ -372,8 +373,7 @@ def quantize_4bit_impl( |
372 | 373 |
|
373 | 374 | return out.unsqueeze(0), state |
374 | 375 |
|
375 | | - |
376 | | -#@_maybe_torch_compile |
| 376 | +@_maybe_torch_compile |
377 | 377 | def dequantize_4bit_impl( |
378 | 378 | A: Tensor, |
379 | 379 | quant_state=None, |
|
0 commit comments