Skip to content

Commit 83c3229

Browse files
authored
Fix nvfuser warning (#2694)
1 parent 1d8965c commit 83c3229

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,20 @@
7474
DTENSOR_SUPPORTED_VERSION = LooseVersion("0.2.28")
7575
if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
7676
import nvfuser_direct as nvfuser
77-
from nvfuser_direct import DataType, FusionDefinition, multidevice, ParallelType, execute_with_dtensors
77+
from nvfuser_direct import (
78+
DataType,
79+
FusionDefinition,
80+
multidevice,
81+
ParallelType,
82+
execute_with_dtensors,
83+
compute_tensor_descriptor as nv_compute_td,
84+
)
7885
else:
7986
if nvfuser_version() >= DTENSOR_SUPPORTED_VERSION:
8087
from nvfuser_direct import FusionDefinition as DirectFusionDefinition
8188
from nvfuser_direct import multidevice, ParallelType, execute_with_dtensors
8289
import nvfuser
83-
from nvfuser import DataType, FusionDefinition
90+
from nvfuser import DataType, FusionDefinition, compute_tensor_descriptor as nv_compute_td
8491

8592
#
8693
# Helper functions
@@ -483,8 +490,6 @@ def compute_contiguity(
483490
Returns:
484491
Tuple[Tuple[bool, ...], Tuple[int, ...]]: The contiguity and stride_order
485492
"""
486-
from nvfuser import compute_tensor_descriptor as nv_compute_td
487-
488493
return tuple(tuple(x) for x in nv_compute_td(shape, stride))
489494

490495

0 commit comments

Comments
 (0)