Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/pins/pytorch-upstream.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d0fd42eb3ac6939e63879bc055b6901103f713d3
61dc5e9c0a36d590adc47b4110efd94d9eb59306
12 changes: 6 additions & 6 deletions python/tutorials/08-grouped-gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,9 @@ def group_gemm_fn(group_A, group_B):
g_lds += [A.stride(0), B.stride(0), C.stride(0)]

# note these are device tensors
d_a_ptrs = torch.tensor(A_addrs, device=device)
d_b_ptrs = torch.tensor(B_addrs, device=device)
d_c_ptrs = torch.tensor(C_addrs, device=device)
d_a_ptrs = torch.tensor(A_addrs, device=device, dtype=torch.uint64)
d_b_ptrs = torch.tensor(B_addrs, device=device, dtype=torch.uint64)
d_c_ptrs = torch.tensor(C_addrs, device=device, dtype=torch.uint64)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device)
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device)
# we use a fixed number of CTA, and it's auto-tunable
Expand Down Expand Up @@ -276,9 +276,9 @@ def benchmark(N, provider):
g_sizes += [N, N, N]
g_lds += [N, N, N]

d_a_ptrs = torch.tensor(A_addrs, device="xpu")
d_b_ptrs = torch.tensor(B_addrs, device="xpu")
d_c_ptrs = torch.tensor(C_addrs, device="xpu")
d_a_ptrs = torch.tensor(A_addrs, device="xpu", dtype=torch.uint64)
d_b_ptrs = torch.tensor(B_addrs, device="xpu", dtype=torch.uint64)
d_c_ptrs = torch.tensor(C_addrs, device="xpu", dtype=torch.uint64)
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="xpu")
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="xpu")

Expand Down
1 change: 0 additions & 1 deletion scripts/patch-pytorch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ echo "Applying PyTorch patches in $REPO_ROOT"
cd "$REPO_ROOT"

curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply -
curl -sSL https://github.com/pytorch/pytorch/pull/126456.diff | git apply -
Loading