Skip to content

Commit 3ccab57

Browse files
authored
Update PyTorch pin to the version which support elapsed_time; remove our patch for elapsed_time (#2952)
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3f4fdd1 commit 3ccab57

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

.github/pins/pytorch-upstream.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d0fd42eb3ac6939e63879bc055b6901103f713d3
1+
61dc5e9c0a36d590adc47b4110efd94d9eb59306

python/tutorials/08-grouped-gemm.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,9 @@ def group_gemm_fn(group_A, group_B):
171171
g_lds += [A.stride(0), B.stride(0), C.stride(0)]
172172

173173
# note these are device tensors
174-
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
175-
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
176-
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
174+
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64)
175+
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64)
176+
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64)
177177
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
178178
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
179179
# we use a fixed number of CTA, and it's auto-tunable
@@ -277,9 +277,9 @@ def benchmark(N, provider):
277277
g_sizes += [N, N, N]
278278
g_lds += [N, N, N]
279279

280-
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE)
281-
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE)
282-
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE)
280+
d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64)
281+
d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64)
282+
d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64)
283283
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE)
284284
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE)
285285

scripts/patch-pytorch.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,3 @@ echo "Applying PyTorch patches in $REPO_ROOT"
1616
cd "$REPO_ROOT"
1717

1818
curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
19-
curl -sSL https://github.com/pytorch/pytorch/pull/126456.diff | git apply -

0 commit comments

Comments
 (0)