-
Notifications
You must be signed in to change notification settings - Fork 37
[WIP] [StreamK] #782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main_perf
Are you sure you want to change the base?
[WIP] [StreamK] #782
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,25 @@ | |
| torch.manual_seed(123) | ||
| random.seed(123) | ||
|
|
||
| total_sm = 304 | ||
|
|
||
| def is_hip_cdna3(): | ||
| target = triton.runtime.driver.active.get_current_target() | ||
| return target.backend == 'hip' and target.arch == 'gfx942' | ||
|
|
||
|
|
||
| def is_hip_cdna4(): | ||
| target = triton.runtime.driver.active.get_current_target() | ||
| return target.backend == 'hip' and target.arch == 'gfx950' | ||
|
|
||
|
|
||
| if is_hip_cdna3(): | ||
| total_sm = 304 | ||
| elif is_hip_cdna4(): | ||
| total_sm = 256 | ||
| else: | ||
| print("Unknown target") | ||
| exit(0) | ||
|
|
||
| print(f"total SMs: {total_sm}") | ||
|
|
||
|
|
||
|
|
@@ -138,7 +156,8 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. | |
|
|
||
| ## test for tile counts that are not a multiple of 304 tiles | ||
| #m, n, k = 4096, 4096, 8192 # some problem size to test | ||
| m, n, k = 8192, 8192, 8192 # some problem size to test | ||
| m, n, k = 8192, 8192, 512 # some problem size to test | ||
| #m, n, k = 8704, 8704, 8192 # some problem size to test | ||
| #m, n, k = 512, 512, 512 # some problem size to test | ||
|
|
||
| ## memory bound sizes | ||
|
|
@@ -171,7 +190,14 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. | |
| num_warps = 8 | ||
| waves_per_eu = 0 | ||
| mfmaInstrSize = 16 | ||
| kpack = 2 | ||
|
|
||
| if is_hip_cdna3(): | ||
| kpack = 2 | ||
| elif is_hip_cdna4(): | ||
| kpack = 1 | ||
| else: | ||
| print("Unknown target") | ||
| exit(0) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same questions for MI308 |
||
|
|
||
| ##for total_sm in range(1, 305): | ||
| ## print(f"{total_sm=}") | ||
|
|
@@ -195,7 +221,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, bias: torch. | |
| P = torch.zeros((total_sm, BLK_M * BLK_N), device="cuda", dtype=torch.float32) | ||
| C = matmul.apply(A, B, C, bias, P, locks, total_sm, BLK_M, BLK_N, BLK_K, gsize_m, two_tiles, num_stages, num_warps, | ||
| waves_per_eu, mfmaInstrSize, kpack) | ||
| #exit(0) | ||
| exit(0) | ||
| matmul.set_debug(False) | ||
| expected = A @ B | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -48,7 +48,7 @@ def streamk_gemm( | |
|
|
||
| acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32 | ||
|
|
||
| for tile_id in range(pid, total_full_tiles, NUM_SMS): | ||
| for tile_id in tl.range(pid, total_full_tiles, NUM_SMS, flatten=True): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Have you observed any perf improvement by enabling this loop fusion?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, this is just an experiment and it doesn't compile. I filed a ticket for it https://github.com/ROCm/triton-internal/issues/784 |
||
| num_pid_in_group = GROUP_SIZE_M * num_pid_n | ||
| group_id = tile_id // num_pid_in_group | ||
| first_pid_m = group_id * GROUP_SIZE_M | ||
|
|
@@ -74,8 +74,10 @@ def streamk_gemm( | |
| if not EVEN_K: | ||
| loop_k -= 1 | ||
|
|
||
| tl.assume(loop_k > 1) | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice catch! |
||
| acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) | ||
| for k in range(0, loop_k): | ||
| for k in tl.range(0, loop_k): | ||
| a = tl.load(tl.multiple_of(A_BASE, (1, 16))) | ||
| b = tl.load(tl.multiple_of(B_BASE, (16, 1))) | ||
| acc += tl.dot(a, b) | ||
|
|
@@ -105,6 +107,9 @@ def streamk_gemm( | |
| mask = (rm < M)[:, None] & (rn < N)[None, :] | ||
| tl.store(C_, c, mask=mask) | ||
|
|
||
| if STREAMK_TILES == 0: | ||
| return | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in fact, I was trying to use STREAK_TILES to avoid this "if" with line 113 and 115, otherwise, we could use full tiles and total tiles for this if ?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In theory, when STREAM_TILES==0, start_iter == last_iter and we should not need the while loop. However, they are not constant, so compiler cannot optimize them out. I found this |
||
| tl.assume(pid >= 0) | ||
| total_streamk_iters = STREAMK_TILES * iters_per_tile | ||
| streamk_iters_pcu = total_streamk_iters // NUM_SMS | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how do we deal with MI308 80/64 CUs ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know. There must be a way (hopefully a torch API) to query the number of CUs from the GPU. Let me find it.