67
67
from thunder .executors .nvfuserex import nvfuser_version
68
68
69
69
70
- DTENSOR_SUPPORTED_VERSION = LooseVersion ("0.2.28" )
71
- if nvfuser_version () >= DTENSOR_SUPPORTED_VERSION :
72
- import nvfuser_direct as nvfd
73
- from nvfuser_direct import FusionDefinition as DirectFusionDefinition
74
-
75
70
# NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
76
71
# by nvfuserex.py when nvFuser is available.
77
- import nvfuser
78
- from nvfuser import DataType , FusionDefinition
72
+
73
+ DIRECT_BINDINGS_SUPPORTED_VERSION = LooseVersion ("0.2.34" )
74
+ DTENSOR_SUPPORTED_VERSION = LooseVersion ("0.2.28" )
75
+ if nvfuser_version () >= DIRECT_BINDINGS_SUPPORTED_VERSION :
76
+ import nvfuser_direct as nvfuser
77
+ from nvfuser_direct import DataType , FusionDefinition , multidevice , ParallelType , execute_with_dtensors
78
+ else :
79
+ if nvfuser_version () >= DTENSOR_SUPPORTED_VERSION :
80
+ from nvfuser_direct import FusionDefinition as DirectFusionDefinition
81
+ from nvfuser_direct import multidevice , ParallelType , execute_with_dtensors
82
+ import nvfuser
83
+ from nvfuser import DataType , FusionDefinition
79
84
80
85
#
81
86
# Helper functions
@@ -258,9 +263,9 @@ def multidevice_schedule(fd: FusionDefinition, in_dtensors: list[Proxy]) -> None
258
263
259
264
# nvfuser's DeviceMesh supports torch.Tensor since 0.2.30
260
265
if nvfuser_version () >= LooseVersion ("0.2.30" ):
261
- mesh = nvfd . multidevice .DeviceMesh (in_dtensor .device_mesh .mesh )
266
+ mesh = multidevice .DeviceMesh (in_dtensor .device_mesh .mesh )
262
267
else :
263
- mesh = nvfd . multidevice .DeviceMesh (in_dtensor .device_mesh .mesh .tolist ())
268
+ mesh = multidevice .DeviceMesh (in_dtensor .device_mesh .mesh .tolist ())
264
269
265
270
in_tv .set_device_mesh (mesh )
266
271
@@ -273,7 +278,7 @@ def multidevice_schedule(fd: FusionDefinition, in_dtensors: list[Proxy]) -> None
273
278
if placement .is_shard ():
274
279
dim = cast (Shard , placement ).dim
275
280
in_tv .split (dim , mesh .size , inner_split = False )
276
- in_tv .axis (dim ).parallelize (nvfd . ParallelType .mesh_x )
281
+ in_tv .axis (dim ).parallelize (ParallelType .mesh_x )
277
282
in_tv .set_allocation_domain (in_tv .get_loop_domain (), new_contiguity = True )
278
283
279
284
@@ -354,8 +359,6 @@ def translate_bound_symbol(bsym: BoundSymbol) -> Any:
354
359
nvout = lc_to_nv_map [out ]
355
360
fd .add_output (nvout )
356
361
357
- MAX_LENGTH = 9999
358
-
359
362
if any (isinstance (t , DTensorProxy ) for t in sorted_unique_inputs ):
360
363
# multi-GPU path
361
364
utils .check (
@@ -375,19 +378,15 @@ def check_dtensor_tracing_and_runtime_metadata(inp):
375
378
lambda : "nvfuser: Expected runtime and tracing metadata to be the same for DTensor." ,
376
379
)
377
380
378
- fd = DirectFusionDefinition ()
381
+ fd = FusionDefinition () if nvfuser_version () >= DIRECT_BINDINGS_SUPPORTED_VERSION else DirectFusionDefinition ()
379
382
# Device may be set in one of the "factory" methods like full, iota, or uniform
380
383
# NOTE: This should be set before defining because a factory method may look up `_selected_device` while being defined.
381
384
fd ._selected_device = None
382
385
with fd :
383
386
definition (fd )
384
387
multidevice_schedule (fd , sorted_unique_inputs )
385
388
else :
386
- # NOTE nvFuser's default max length is 1024 operations at the time of this writing
387
- # This arbitrarily increases it to 9999
388
- # TODO Review splitting very large fusions or removing the max length restriction completely
389
- # See "Very large nvFuser fusions hit max_length"
390
- fd = FusionDefinition (max_length = MAX_LENGTH )
389
+ fd = FusionDefinition ()
391
390
# Device may be set in one of the "factory" methods like full, iota, or uniform
392
391
# NOTE: This should be set before defining because a factory method may look up `_selected_device` while being defined.
393
392
fd ._selected_device = None
@@ -539,7 +538,7 @@ def __call__(self, *args):
539
538
540
539
if dist .is_available () and any (isinstance (t , torch .distributed .tensor .DTensor ) for t in args ):
541
540
with annotate_for_profile (self .name ):
542
- output = nvfd . execute_with_dtensors (fd , args )
541
+ output = execute_with_dtensors (fd , args )
543
542
return output
544
543
else :
545
544
with annotate_for_profile (self .name ):
@@ -634,42 +633,6 @@ def __init__(self):
634
633
else :
635
634
self .set_fuel (FUEL_LEVEL .UNLIMITED )
636
635
637
- env_var_save_serde = os .getenv ("ENABLE_NVFUSER_SERIALIZATION" , None )
638
- save_serde : bool = env_var_save_serde in ("true" , "1" )
639
- self .write_cache_on_exit (save_serde )
640
-
641
- def write_cache_on_exit (self , save_cache : bool = False ):
642
- """
643
- Selects whether nvFuser writes its cache when the program exits.
644
-
645
- Args:
646
- save_cache (bool): A flag that enables saving nvFuser cache.
647
- Defaults to False.
648
-
649
- nvFuser's serialization will save the FusionCache data structure and any
650
- CUDA cubins into a FlatBuffer binary upon exiting the python program.
651
- The binary is stored in /tmp/nvfuser_kernel_db/ with the filename
652
- nvf_serde_[local_rank]_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor].
653
-
654
- Details:
655
- * If the common workspace exists, nvFuser will load it automatically
656
- when the FusionCache is constructed.
657
- * When this function is enabled, then when the program exits NvFuser
658
- will save the FusionCache, overwriting the previous common workspace.
659
- * If this function is disabled, then when the program exits NvFuser
660
- does nothing. The previous common workspace is preserved if it exists.
661
- * If there are any issues when loading the serialized binary, it is
662
- deleted and the FusionCache is created with its default constructor.
663
- * When the LOCAL_RANK environment variable is set for ddp or fsdp, a
664
- separate fusion cache is saved for each device.
665
- """
666
- from nvfuser import enable_automatic_serialization , disable_automatic_serialization
667
-
668
- if save_cache :
669
- enable_automatic_serialization ()
670
- else :
671
- disable_automatic_serialization ()
672
-
673
636
def get_fuel (self , amount : int = 1 , / ) -> bool :
674
637
if self ._optimization_fuel is FUEL_LEVEL .UNLIMITED :
675
638
return True
0 commit comments