diff --git a/.github/pins/pytorch-upstream.txt b/.github/pins/pytorch-upstream.txt index aa41aecae0..18f3d3ae5b 100644 --- a/.github/pins/pytorch-upstream.txt +++ b/.github/pins/pytorch-upstream.txt @@ -1 +1 @@ -d0fd42eb3ac6939e63879bc055b6901103f713d3 +61dc5e9c0a36d590adc47b4110efd94d9eb59306 diff --git a/python/tutorials/08-grouped-gemm.py b/python/tutorials/08-grouped-gemm.py index 6f55fd9dcc..aec67565dc 100644 --- a/python/tutorials/08-grouped-gemm.py +++ b/python/tutorials/08-grouped-gemm.py @@ -171,9 +171,9 @@ def group_gemm_fn(group_A, group_B): g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors - d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) - d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) - d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) # we use a fixed number of CTA, and it's auto-tunable @@ -277,9 +277,9 @@ def benchmark(N, provider): g_sizes += [N, N, N] g_lds += [N, N, N] - d_a_ptrs = torch.tensor(A_addrs, device=DEVICE) - d_b_ptrs = torch.tensor(B_addrs, device=DEVICE) - d_c_ptrs = torch.tensor(C_addrs, device=DEVICE) + d_a_ptrs = torch.tensor(A_addrs, device=DEVICE, dtype=torch.uint64) + d_b_ptrs = torch.tensor(B_addrs, device=DEVICE, dtype=torch.uint64) + d_c_ptrs = torch.tensor(C_addrs, device=DEVICE, dtype=torch.uint64) d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=DEVICE) d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=DEVICE) diff --git a/scripts/patch-pytorch.sh b/scripts/patch-pytorch.sh index c9dbf931ca..5e35d25441 100755 --- a/scripts/patch-pytorch.sh +++ b/scripts/patch-pytorch.sh @@ -16,4 +16,3 @@ echo "Applying PyTorch patches in $REPO_ROOT" cd "$REPO_ROOT" curl -sSL https://github.com/pytorch/pytorch/pull/126516.diff | git apply - -curl -sSL https://github.com/pytorch/pytorch/pull/126456.diff | git apply -