This change sets the launch grid attribute before calling
cuLaunchKernelEx.
It is intended to pair with the load/store atomics from
triton-lang/triton#5187
and to enable grid synchronization similar to what cooperative
groups provide.
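For reviewers, the launch-path decision can be sketched in plain Python. This is only a model of the condition changed in the diff, not the generated C code: `LaunchPlan` and `plan_launch` are hypothetical names, the strings stand in for the CUDA driver enum values of the same name, and the use of `cuLaunchKernel` on the simple path and of a cluster-dimension attribute for `num_ctas != 1` are assumptions that this description does not confirm.

```python
from dataclasses import dataclass, field

# Strings standing in for the CUDA driver enum values of the same name.
COOPERATIVE = "CU_LAUNCH_ATTRIBUTE_COOPERATIVE"
CLUSTER_DIMENSION = "CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION"


@dataclass
class LaunchPlan:
    """Hypothetical record of the driver entry point and attributes chosen."""
    entry_point: str
    attributes: list = field(default_factory=list)


def plan_launch(num_ctas: int, launch_cooperative_grid: int) -> LaunchPlan:
    # Mirrors the new condition: the simple path is taken only when there is
    # a single CTA and no cooperative launch was requested.
    if num_ctas == 1 and launch_cooperative_grid == 0:
        return LaunchPlan("cuLaunchKernel")  # assumed simple-path entry point

    attrs = []
    if launch_cooperative_grid:
        # The attribute this change sets before calling cuLaunchKernelEx.
        attrs.append(COOPERATIVE)
    if num_ctas != 1:
        attrs.append(CLUSTER_DIMENSION)  # assumed attribute for multi-CTA launches
    return LaunchPlan("cuLaunchKernelEx", attrs)


if __name__ == "__main__":
    for ctas, coop in [(1, 0), (1, 1), (2, 0), (2, 1)]:
        print(ctas, coop, plan_launch(ctas, coop))
```

The point of the sketch is that a single-CTA launch with `launch_cooperative_grid` set no longer takes the simple path, so the attribute is always attached when the flag is on.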
@ptillet Any recommendations on the UI for using this in code would be
most welcome :-)
- [X] I am not making a trivial change, such as fixing a typo in a
comment.
- [X] I have written a PR description following these
[rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/python/test` for end-to-end tests
  - [?] This PR does not need a test because:
    I am not yet sure how to test that one driver API attribute is used
    rather than another in this case.
    I added a test that exercises the `launch_cooperative_grid=True` launch
    flag, but the test does not confirm that the plumbing actually sets the
    API attribute. I did confirm that it does offline, using an assert.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)
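One possible way to cover the plumbing gap noted in the checklist, sketched with a recording test double rather than a real driver call. Everything here is hypothetical: `RecordingLauncher` and `launch_with_flags` are illustrative names and not part of Triton. The sketch only checks that a `launch_cooperative_grid` flag passed at the call site reaches the launcher, and it does not verify which driver attribute is set.

```python
class RecordingLauncher:
    """Hypothetical stand-in for the generated launcher; records each call."""

    def __init__(self):
        self.calls = []

    def __call__(self, grid, num_ctas=1, launch_cooperative_grid=False):
        grid_x, grid_y, grid_z = grid
        # Skip empty grids, like the `gridX*gridY*gridZ > 0` check in _launch.
        if grid_x * grid_y * grid_z <= 0:
            return
        self.calls.append({
            "grid": grid,
            "num_ctas": num_ctas,
            "cooperative": bool(launch_cooperative_grid),
        })


def launch_with_flags(launcher, grid, **flags):
    """Hypothetical front end that forwards launch flags unchanged."""
    launcher(grid, **flags)


def test_cooperative_flag_reaches_launcher():
    launcher = RecordingLauncher()
    launch_with_flags(launcher, (4, 1, 1), launch_cooperative_grid=True)
    launch_with_flags(launcher, (4, 1, 1))
    launch_with_flags(launcher, (0, 1, 1), launch_cooperative_grid=True)
    # The empty grid is skipped, so only two launches are recorded.
    assert [c["cooperative"] for c in launcher.calls] == [True, False]


if __name__ == "__main__":
    test_cooperative_flag_reaches_launcher()
```

A check against the real launcher would still need a hook at the driver boundary, which is the part left open above.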
-static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
+static void _launch(int gridX, int gridY, int gridZ, int num_warps, int num_ctas, int launch_cooperative_grid, int clusterDimX, int clusterDimY, int clusterDimZ, int shared_memory, CUstream stream, CUfunction function, CUdeviceptr global_scratch{', ' + arg_decls if len(arg_decls) > 0 else ''}) {{
   void *params[] = {{ {', '.join(params)} }};
   if (gridX*gridY*gridZ > 0) {{
-    if (num_ctas == 1) {{
+    if ((num_ctas == 1) && (0 == launch_cooperative_grid)) {{