
Commit 126969f
fixes
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent: bdd3bad

2 files changed: +3 -3 lines changed

python/tutorials/01-vector-add.py
Lines changed: 1 addition & 2 deletions

@@ -62,8 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    is_dvc = f"is_{DEVICE}"
-    assert getattr(x, is_dvc) and getattr(y, is_dvc) and getattr(output, is_dvc)
+    assert x.device.type == DEVICE and y.device.type == DEVICE and output.device.type == DEVICE
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
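For reference, a minimal sketch (not part of the commit) of how the new assertion behaves. It assumes DEVICE is a device-type string as in the tutorial; "cpu" is used here only so the snippet runs without a GPU.

import torch

DEVICE = "cpu"  # stands in for "cuda" or "xpu" from the tutorial; "cpu" keeps the sketch runnable anywhere

x = torch.arange(4, device=DEVICE)
y = torch.ones_like(x)
output = torch.empty_like(x)
# tensor.device.type is the backend name as a plain string ("cpu", "cuda", "xpu", ...),
# so one string comparison covers every backend without per-backend is_* attributes.
assert x.device.type == DEVICE and y.device.type == DEVICE and output.device.type == DEVICE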

python/tutorials/02-fused-softmax.py
Lines changed: 2 additions & 1 deletion

@@ -112,7 +112,8 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-properties = driver.active.utils.get_device_properties(DEVICE)
+device = getattr(torch, DEVICE).current_device()
+properties = driver.active.utils.get_device_properties(device)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
 WARPS_PER_EU = 8  # TODO: Get from properties
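For reference, a minimal sketch (not part of the commit) of the device-index lookup used above. It assumes DEVICE is a backend name such as "cuda" or "xpu" that has a matching torch submodule.

import torch

DEVICE = "cuda"  # hypothetical value; torch.xpu works the same way on Intel GPUs

backend = getattr(torch, DEVICE, None)  # e.g. the torch.cuda module
if backend is not None and backend.is_available():
    device = backend.current_device()  # integer index of the active device, e.g. 0
    print(device)
    # The commit passes this integer index to driver.active.utils.get_device_properties
    # instead of the DEVICE string.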
