Commit 8b542cf

Enable direct bindings in Thunder (#2502)
1 parent 602fdbd commit 8b542cf

File tree

1 file changed: +18 -55 lines changed

thunder/executors/nvfuserex_impl.py

Lines changed: 18 additions & 55 deletions
@@ -67,15 +67,20 @@
 from thunder.executors.nvfuserex import nvfuser_version
 
 
-DTENSOR_SUPPORTED_VERSION = LooseVersion("0.2.28")
-if nvfuser_version() >= DTENSOR_SUPPORTED_VERSION:
-    import nvfuser_direct as nvfd
-    from nvfuser_direct import FusionDefinition as DirectFusionDefinition
-
 # NOTE This impl file is here because nvFuser may not be available, so it's imported conditionally
 # by nvfuserex.py when nvFuser is available.
-import nvfuser
-from nvfuser import DataType, FusionDefinition
+
+DIRECT_BINDINGS_SUPPORTED_VERSION = LooseVersion("0.2.34")
+DTENSOR_SUPPORTED_VERSION = LooseVersion("0.2.28")
+if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
+    import nvfuser_direct as nvfuser
+    from nvfuser_direct import DataType, FusionDefinition, multidevice, ParallelType, execute_with_dtensors
+else:
+    if nvfuser_version() >= DTENSOR_SUPPORTED_VERSION:
+        from nvfuser_direct import FusionDefinition as DirectFusionDefinition
+        from nvfuser_direct import multidevice, ParallelType, execute_with_dtensors
+    import nvfuser
+    from nvfuser import DataType, FusionDefinition
 
 #
 # Helper functions
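
Note on the new import gate: with direct bindings (nvFuser >= 0.2.34) the module imports nvfuser_direct under the name nvfuser, so the rest of the file uses one set of names (DataType, FusionDefinition, multidevice, ParallelType, execute_with_dtensors) either way. A minimal standalone sketch of the same pattern, assuming LooseVersion comes from the looseversion package and stubbing nvfuser_version for illustration:

    from looseversion import LooseVersion

    DIRECT_BINDINGS_SUPPORTED_VERSION = LooseVersion("0.2.34")

    def nvfuser_version() -> LooseVersion:
        # Stand-in for thunder.executors.nvfuserex.nvfuser_version.
        return LooseVersion("0.2.34")

    if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
        # Direct bindings: one module supplies everything under uniform names.
        import nvfuser_direct as nvfuser
        from nvfuser_direct import DataType, FusionDefinition
    else:
        # Legacy bindings: keep the old module and names.
        import nvfuser
        from nvfuser import DataType, FusionDefinition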
@@ -258,9 +263,9 @@ def multidevice_schedule(fd: FusionDefinition, in_dtensors: list[Proxy]) -> None
 
         # nvfuser's DeviceMesh supports torch.Tensor since 0.2.30
         if nvfuser_version() >= LooseVersion("0.2.30"):
-            mesh = nvfd.multidevice.DeviceMesh(in_dtensor.device_mesh.mesh)
+            mesh = multidevice.DeviceMesh(in_dtensor.device_mesh.mesh)
         else:
-            mesh = nvfd.multidevice.DeviceMesh(in_dtensor.device_mesh.mesh.tolist())
+            mesh = multidevice.DeviceMesh(in_dtensor.device_mesh.mesh.tolist())
 
         in_tv.set_device_mesh(mesh)
 
@@ -273,7 +278,7 @@ def multidevice_schedule(fd: FusionDefinition, in_dtensors: list[Proxy]) -> None
         if placement.is_shard():
             dim = cast(Shard, placement).dim
             in_tv.split(dim, mesh.size, inner_split=False)
-            in_tv.axis(dim).parallelize(nvfd.ParallelType.mesh_x)
+            in_tv.axis(dim).parallelize(ParallelType.mesh_x)
         in_tv.set_allocation_domain(in_tv.get_loop_domain(), new_contiguity=True)
 
 
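Both multidevice_schedule hunks drop the nvfd. prefix because the gated import now binds multidevice and ParallelType directly at module scope. Read together, the schedule for a Shard(dim) placement does roughly the following (a paraphrase of the diff; fd, in_tv, in_dtensor, cast, and Shard are assumed to come from the elided surrounding code):

    # Build nvFuser's mesh from the DTensor mesh (a torch.Tensor is accepted
    # since nvFuser 0.2.30, otherwise a plain Python list is required).
    mesh = multidevice.DeviceMesh(in_dtensor.device_mesh.mesh)
    in_tv.set_device_mesh(mesh)

    for placement in in_dtensor.placements:
        if placement.is_shard():
            dim = cast(Shard, placement).dim
            # Outer-split the sharded axis by the mesh size, then bind the
            # outer factor to the mesh axis so each device owns one slice.
            in_tv.split(dim, mesh.size, inner_split=False)
            in_tv.axis(dim).parallelize(ParallelType.mesh_x)
    in_tv.set_allocation_domain(in_tv.get_loop_domain(), new_contiguity=True)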
@@ -354,8 +359,6 @@ def translate_bound_symbol(bsym: BoundSymbol) -> Any:
             nvout = lc_to_nv_map[out]
             fd.add_output(nvout)
 
-    MAX_LENGTH = 9999
-
     if any(isinstance(t, DTensorProxy) for t in sorted_unique_inputs):
         # multi-GPU path
         utils.check(
@@ -375,19 +378,15 @@ def check_dtensor_tracing_and_runtime_metadata(inp):
             lambda: "nvfuser: Expected runtime and tracing metadata to be the same for DTensor.",
         )
 
-        fd = DirectFusionDefinition()
+        fd = FusionDefinition() if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION else DirectFusionDefinition()
         # Device may be set in one of the "factory" methods like full, iota, or uniform
         # NOTE: This should be called before defining because a factory method may look-up at `_selected_device` while being defined.
         fd._selected_device = None
         with fd:
             definition(fd)
             multidevice_schedule(fd, sorted_unique_inputs)
     else:
-        # NOTE nvFuser's default max length is 1024 operations at the time of this writing
-        # This arbitrarily increases it to 9999
-        # TODO Review splititng very large fusions or removing the max length restriction completely
-        # See "Very large nvFuser fusions hit max_length"
-        fd = FusionDefinition(max_length=MAX_LENGTH)
+        fd = FusionDefinition()
         # Device may be set in one of the "factory" methods like full, iota, or uniform
         # NOTE: This should be called before defining because a factory method may look-up at `_selected_device` while being defined.
         fd._selected_device = None
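
Two related simplifications here: the multi-GPU branch now picks its definition class at runtime instead of assuming the direct bindings are installed, and the single-GPU branch drops the max_length=9999 workaround (per the deleted comments, a hack around the legacy bindings' default cap of 1024 operations). A condensed sketch of the dispatch, assuming the gated imports from the top of the file:

    # With direct bindings, FusionDefinition already is the direct class;
    # otherwise fall back to the DirectFusionDefinition alias, which the
    # gated import provides when nvFuser >= 0.2.28.
    if nvfuser_version() >= DIRECT_BINDINGS_SUPPORTED_VERSION:
        fd = FusionDefinition()
    else:
        fd = DirectFusionDefinition()
    fd._selected_device = None  # set before defining; factory ops read it
    with fd:
        definition(fd)
        multidevice_schedule(fd, sorted_unique_inputs)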
@@ -539,7 +538,7 @@ def __call__(self, *args):
 
         if dist.is_available() and any(isinstance(t, torch.distributed.tensor.DTensor) for t in args):
             with annotate_for_profile(self.name):
-                output = nvfd.execute_with_dtensors(fd, args)
+                output = execute_with_dtensors(fd, args)
             return output
         else:
             with annotate_for_profile(self.name):
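
With execute_with_dtensors now a module-level name from the gated import, the runtime dispatch is a plain lookup. A hedged sketch of the call path; run_fusion is a hypothetical wrapper, and the single-device fallback via fd.execute is assumed from the elided else-branch rather than shown in this hunk:

    import torch.distributed as dist
    from torch.distributed.tensor import DTensor

    def run_fusion(fd, args):
        # Route DTensor inputs through nvFuser's DTensor executor;
        # everything else goes through the ordinary execute path.
        if dist.is_available() and any(isinstance(t, DTensor) for t in args):
            return execute_with_dtensors(fd, args)
        return fd.execute(args)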
@@ -634,42 +633,6 @@ def __init__(self):
         else:
             self.set_fuel(FUEL_LEVEL.UNLIMITED)
 
-        env_var_save_serde = os.getenv("ENABLE_NVFUSER_SERIALIZATION", None)
-        save_serde: bool = env_var_save_serde in ("true", "1")
-        self.write_cache_on_exit(save_serde)
-
-    def write_cache_on_exit(self, save_cache: bool = False):
-        """
-        Selects whether nvFuser writes its cache when the program exits.
-
-        Args:
-            save_cache (bool): A flag that enables saving nvFuser cache.
-                Defaults to False.
-
-        nvFuser's serialization will save the FusionCache data structure and any
-        CUDA cubins into a FlatBuffer binary upon exiting the python program.
-        The binary is stored in /tmp/nvfuser_kernel_db/ with the filename
-        nvf_serde_[local_rank]_[cuda_major]_[cuda_minor]_[nvrtc_major]_[nvrtc_minor].
-
-        Details:
-        * If the common workspace is exists, nvFuser will load it automatically
-          when the FusionCache is constructed.
-        * When this function is enabled, then when the program exits NvFuser
-          will save the FusionCache, overwritting the previous common workspace.
-        * If this function is disabled, then when the program exits NvFuser
-          does nothing. The previous common workspace is preserved if it exists.
-        * If there are any issues when loading the serialized binary, it is
-          deleted and the FusionCache is created with its default constructor.
-        * When the LOCAL_RANK environment variable is set for ddp or fsdp, a
-          separate fusion cache is saved for each device.
-        """
-        from nvfuser import enable_automatic_serialization, disable_automatic_serialization
-
-        if save_cache:
-            enable_automatic_serialization()
-        else:
-            disable_automatic_serialization()
-
     def get_fuel(self, amount: int = 1, /) -> bool:
         if self._optimization_fuel is FUEL_LEVEL.UNLIMITED:
             return True
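
For readers tracking the removal: the deleted write_cache_on_exit hook only mapped an environment variable onto the legacy bindings' serialization toggles, an API that (judging by this commit) the direct bindings do not carry. What the removed code did, condensed for reference:

    import os
    # These two functions come from the legacy `nvfuser` module.
    from nvfuser import enable_automatic_serialization, disable_automatic_serialization

    if os.getenv("ENABLE_NVFUSER_SERIALIZATION") in ("true", "1"):
        enable_automatic_serialization()
    else:
        disable_automatic_serialization()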
