@@ -333,7 +333,7 @@ def TTG_GlobalScratchAllocOp : TTG_Op<"global_scratch_alloc", [MemoryEffects<[Me
333333}
334334
335335def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
336- RecursiveMemoryEffects, RecursivelySpeculatable,
336+ RecursiveMemoryEffects, RecursivelySpeculatable, AsyncRegions,
337337 DeclareOpInterfaceMethods<RegionBranchOpInterface>
338338]> {
339339 let summary = "asynchronously execute code on multiple warpgroups";
@@ -362,21 +362,24 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
362362 }
363363 partition0(%arg0: i32, %arg1: i32) num_warps(8) {
364364 some_async_dispatch(%arg0, %arg1)
365+ ttg.warp_return
365366 }
366367 partition1(%arg0: i32, %arg1: i32) num_warps(1) {
367368 some_async_dispatch(%arg0, %arg1)
369+ ttg.warp_return
368370 } : (i32, i32) -> i32
369371 ```
370372 }];
371373
372374 let arguments = (ins
373375 Variadic<AnyType>:$explicitCaptures,
374- DenseI32ArrayAttr:$partitionNumWarps
376+ DenseI32ArrayAttr:$partitionNumWarps,
377+ OptionalAttr<DenseI32ArrayAttr>:$warpGroupStartIds
375378 );
376379 let results = (outs Variadic<AnyType>:$defaultPassthrough);
377380
378381 let regions = (region
379- SizedRegion <1>:$defaultRegion,
382+ MinSizedRegion <1>:$defaultRegion,
380383 SizedRegion<1>:$partitionOpHolder
381384 );
382385
@@ -390,20 +393,19 @@ def TTG_WarpSpecializeOp : TTG_Op<"warp_specialize", [
390393
391394def TTG_WarpSpecializePartitionsOp : TTG_Op<"warp_specialize.partitions", [
392395 IsolatedFromAbove, RecursiveMemoryEffects, RecursivelySpeculatable,
393- Terminator, HasParent<"WarpSpecializeOp">,
394- SingleBlockImplicitTerminator<"WarpReturnOp">
396+ Terminator, HasParent<"WarpSpecializeOp">
395397]> {
396398 let summary = "container op for `ttg.warp_specialize`";
397399 let description = [{
398400 Because MLIR requires entire operations be isolated from above, this op
399401 contains the actual isolated from above regions of `ttg.warp_specialize`.
400402 }];
401403
402- let regions = (region VariadicRegion<SizedRegion <1>>:$partitionRegions);
404+ let regions = (region VariadicRegion<MinSizedRegion <1>>:$partitionRegions);
403405}
404406
405407def TTG_WarpYieldOp : TTG_Op<"warp_yield", [
406- Pure, Terminator, HasParent<"WarpSpecializeOp">,
408+ Pure, Terminator, ReturnLike, HasParent<"WarpSpecializeOp">,
407409 DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>
408410]> {
409411 let summary = "yield from the default region of `ttg.warp_specialize`";
@@ -422,6 +424,7 @@ def TTG_WarpYieldOp : TTG_Op<"warp_yield", [
422424 let arguments = (ins Variadic<AnyType>:$values);
423425
424426 let assemblyFormat = "($values^)? attr-dict (`:` type($values)^)?";
427+ let hasVerifier = 1;
425428}
426429
427430def TTG_WarpReturnOp : TTG_Op<"warp_return", [
0 commit comments