Commit 0b4feb7

[testing] moved di = torch._dynamo.device_interface into backend (#4818)
1 parent e7ec3fe commit 0b4feb7

3 files changed: +10 -1 lines changed


python/triton/testing.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 from contextlib import contextmanager
 from typing import Any, Dict, List
 from . import language as tl
+from . import runtime


 def nvsmi(attrs):
@@ -114,7 +115,7 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu
     assert return_mode in ["min", "max", "mean", "median", "all"]
     import torch

-    di = torch._dynamo.device_interface.get_interface_for_device(device_type)
+    di = runtime.driver.active.get_device_interface()

     fn()
     di.synchronize()
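
For context, a minimal sketch of how the benchmarking path consumes the backend-provided interface (not part of this commit; it only assumes the returned module exposes synchronize() and Event(enable_timing=True), which is the surface do_bench relies on):

# Sketch only, not part of this commit. The returned module is assumed to
# behave like torch.cuda: synchronize() plus Event(enable_timing=True).
from triton import runtime

def time_once_ms(fn):
    di = runtime.driver.active.get_device_interface()  # torch.cuda on NVIDIA and ROCm builds
    start, end = di.Event(enable_timing=True), di.Event(enable_timing=True)
    di.synchronize()          # drain pending work before timing
    start.record()
    fn()
    end.record()
    di.synchronize()          # wait for fn's kernels to finish
    return start.elapsed_time(end)  # elapsed time in milliseconds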

third_party/amd/backend/driver.py

Lines changed: 4 additions & 0 deletions
@@ -484,6 +484,10 @@ def __init__(self):
         self.utils = HIPUtils()
         self.launcher_cls = HIPLauncher

+    def get_device_interface(self):
+        import torch
+        return torch.cuda
+
     @staticmethod
     def is_active():
         import torch
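
The AMD backend can return torch.cuda because the ROCm build of PyTorch exposes HIP devices through the torch.cuda namespace. A quick way to see which runtime actually backs it (sketch, not part of this commit):

# Sketch, not part of this commit: on a ROCm wheel torch.version.hip is a
# version string, on a CUDA wheel it is None, so torch.cuda can stand in for
# the HIP device interface here.
import torch

backend = "HIP/ROCm" if torch.version.hip else "CUDA"
print(f"torch.cuda is backed by {backend}")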

third_party/nvidia/backend/driver.py

Lines changed: 4 additions & 0 deletions
@@ -440,6 +440,10 @@ def get_current_target(self):
         warp_size = 32
         return GPUTarget("cuda", capability, warp_size)

+    def get_device_interface(self):
+        import torch
+        return torch.cuda
+
     @staticmethod
     def is_active():
         import torch
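
Out-of-tree backends would implement the same hook. The sketch below is hypothetical (the driver class name and torch.xpu are illustrative, not from this commit); the only contract assumed is that the returned module offers the synchronize()/Event() surface that do_bench uses.

# Hypothetical sketch, not from this commit: an out-of-tree backend driver
# returning whichever torch device module matches its hardware.
class MyAcceleratorDriver:

    def get_device_interface(self):
        import torch
        return torch.xpu  # illustrative; any module with synchronize()/Event() works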
