Commit 5e8e4bd

address review comments
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 1825764 commit 5e8e4bd

4 files changed: +5 −8 lines

python/tutorials/01-vector-add.py

Lines changed: 1 addition & 2 deletions
@@ -62,8 +62,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector.
 def add(x: torch.Tensor, y: torch.Tensor):
     # We need to preallocate the output.
     output = torch.empty_like(x)
-    is_dvc = f"is_{DEVICE}"
-    assert getattr(x, is_dvc) and getattr(y, is_dvc) and getattr(output, is_dvc)
+    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
     n_elements = output.numel()
     # The SPMD launch grid denotes the number of kernel instances that run in parallel.
     # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int].
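The new assertion compares torch.device objects directly instead of probing backend-specific is_<device> attributes, which only worked while DEVICE was a backend name string. A minimal sketch of the pattern, assuming DEVICE is a torch.device defined elsewhere in the tutorial (the CPU fallback below is only for illustration):

    import torch

    # Assumption for this sketch: DEVICE is a torch.device; the tutorial defines it elsewhere.
    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # We need to preallocate the output.
        output = torch.empty_like(x)
        # Direct device comparison works for any torch.device, unlike the old
        # getattr(x, f"is_{DEVICE}") probe, which assumed a backend name such as "cuda" or "xpu".
        assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
        output = x + y  # stand-in for the Triton add_kernel launch in the tutorial
        return output

    x = torch.rand(1024, device=DEVICE)
    y = torch.rand(1024, device=DEVICE)
    out = add(x, y)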

python/tutorials/02-fused-softmax.py

Lines changed: 1 addition & 2 deletions
@@ -112,8 +112,7 @@ def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n
 # %%
 # We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.
 
-device = getattr(torch, DEVICE).current_device()
-properties = driver.active.utils.get_device_properties(device)
+properties = driver.active.utils.get_device_properties(DEVICE)
 NUM_SM = properties["multiprocessor_count"]
 SIZE_SMEM = properties["max_shared_mem"]
 WARPS_PER_EU = 8  # TODO: Get from properties
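With this change the tutorial hands the DEVICE object straight to the driver's property query instead of resolving a current-device index via getattr(torch, DEVICE). A minimal sketch of the updated call, assuming DEVICE is a torch.device, that a GPU backend is active, and that the active driver's get_device_properties accepts the device object (as the edited tutorial line does); the property keys come from the surrounding tutorial code:

    import torch
    from triton.runtime import driver  # same import the tutorial uses

    # Assumption for this sketch: DEVICE is a torch.device defined elsewhere in the tutorial.
    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    properties = driver.active.utils.get_device_properties(DEVICE)
    NUM_SM = properties["multiprocessor_count"]
    SIZE_SMEM = properties["max_shared_mem"]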

python/tutorials/04-low-memory-dropout.py

Lines changed: 3 additions & 3 deletions
@@ -73,10 +73,10 @@ def dropout(x, x_keep, p):
 
 
 # Input tensor
-x = getattr(torch.randn(size=(10, )), DEVICE)()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Dropout mask
 p = 0.5
-x_keep = getattr((torch.rand(size=(10, )) > p).to(torch.int32), DEVICE)()
+x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)
 #
 output = dropout(x, x_keep=x_keep, p=p)
 print(tabulate.tabulate([
@@ -140,7 +140,7 @@ def seeded_dropout(x, p, seed):
     return output
 
 
-x = getattr(torch.randn(size=(10, )), DEVICE)()
+x = torch.randn(size=(10, ), device=DEVICE)
 # Compare this to the baseline - dropout mask is never instantiated!
 output = seeded_dropout(x, p=0.5, seed=123)
 output2 = seeded_dropout(x, p=0.5, seed=123)
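Both hunks replace the getattr(tensor, DEVICE)() move (i.e. t.cuda() / t.xpu() style calls) with allocation directly on the target device. A minimal sketch, assuming DEVICE is a torch.device defined elsewhere in the tutorial:

    import torch

    # Assumption for this sketch: DEVICE is a torch.device; the tutorial defines it elsewhere.
    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    p = 0.5
    # Allocating with device=DEVICE avoids a host-side allocation followed by a copy,
    # and works the same whether DEVICE is a CUDA, XPU, or CPU device.
    x = torch.randn(size=(10, ), device=DEVICE)
    x_keep = (torch.rand(size=(10, ), device=DEVICE) > p).to(torch.int32)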

python/tutorials/08-grouped-gemm.py

Lines changed: 0 additions & 1 deletion
@@ -147,7 +147,6 @@ def grouped_matmul_kernel(
 
 
 def group_gemm_fn(group_A, group_B):
-    device = torch.device(DEVICE)
     assert len(group_A) == len(group_B)
     group_size = len(group_A)
