@@ -1097,6 +1097,10 @@ def GPU_YieldOp : GPU_Op<"yield", [Pure, ReturnLike, Terminator]>,
10971097 ```
10981098 }];
10991099
1100+ let builders = [
1101+ OpBuilder<(ins), [{ /* nothing to do */ }]>
1102+ ];
1103+
11001104 let assemblyFormat = "attr-dict ($values^ `:` type($values))?";
11011105}
11021106
@@ -2921,4 +2925,138 @@ def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> {
29212925 }];
29222926}
29232927
2928+ def GPU_WarpExecuteOnLane0Op : GPU_Op<"warp_execute_on_lane_0",
2929+ [DeclareOpInterfaceMethods<RegionBranchOpInterface, ["areTypesCompatible"]>,
2930+ SingleBlockImplicitTerminator<"gpu::YieldOp">,
2931+ RecursiveMemoryEffects]> {
2932+ let summary = "Executes operations in the associated region on thread #0 of an "
2933+ "SPMD program";
2934+ let description = [{
2935+ `warp_execute_on_lane_0` is an operation used to bridge the gap between
2936+ vector programming and SPMD programming models like GPU SIMT. It allows one
2937+ to trivially convert a region of vector code meant to run on multiple threads
2938+ into a valid SPMD region, and then allows incremental transformations to
2939+ distribute vector operations across the threads.
2940+
2941+ Any code present in the region is only executed on the first thread/lane,
2942+ selected by the `laneid` operand. The `laneid` operand is an integer ID in
2943+ the range [0, `warp_size`). The `warp_size` attribute indicates the number
2944+ of lanes in a warp.
2945+
2946+ Operands are vector values distributed on all lanes that may be used by
2947+ the single-lane execution. The matching region argument is a vector of all
2948+ the values of those lanes available to the single active lane. The
2949+ distributed dimension is implicit based on the shape of the operand and
2950+ argument. The properties of the distribution may be described by extra
2951+ attributes (e.g. an affine map).
2952+
2953+ Return values are distributed on all lanes using `laneid` as the index. The
2954+ vector is distributed based on the shape ratio between the vector type of
2955+ the yield and the result type.
2956+ If the shapes are the same, the value is broadcast to all lanes.
2957+ In the future the distribution can be made more explicit using affine maps
2958+ and will support having multiple IDs.
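+
+ For example, a same-shape yield is a pure broadcast. A minimal sketch with
+ illustrative value names, eliding the lane 0 code that produces `%v`:
+ ```
+ %b = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<1xf32>) {
+ ...
+ gpu.yield %v : vector<1xf32>
+ }
+ // Every lane receives the same vector<1xf32> value.
+ ```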
2959+
2960+ Therefore the `warp_execute_on_lane_0` operation allows implicitly copying
2961+ between lane 0 and the lanes of the warp. When distributing a vector
2962+ from lane 0 to all the lanes, the data are distributed in a block cyclic way.
2963+ For example `vector<64xf32>` gets distributed on 32 threads and maps to
2964+ `vector<2xf32>`, where thread 0 contains vector[0] and vector[1].
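+
+ In IR, the `vector<64xf32>` example above reads as the following sketch,
+ again with illustrative names and the lane 0 body elided:
+ ```
+ %r = gpu.warp_execute_on_lane_0(%laneid)[32] -> (vector<2xf32>) {
+ ...
+ gpu.yield %full : vector<64xf32>
+ }
+ // Lane 0 holds elements [0, 1], lane 1 holds elements [2, 3], and so on.
+ ```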
2965+
2966+ During lowering, values passed as operands and return values need to be
2967+ visible to different lanes within the warp. This would usually be done by
2968+ going through memory.
2969+
2970+ The region is *not* isolated from above. For values coming from the parent
2971+ region without going through the operands, only the lane 0 value will be
2972+ accessible, so this generally only makes sense for uniform values.
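+
+ A minimal sketch of such an implicit capture, assuming the illustrative
+ value `%uniform` holds the same value on every lane:
+ ```
+ %uniform = arith.constant 1.0 : f32
+ gpu.warp_execute_on_lane_0(%laneid)[32] {
+ // Only the lane 0 value of %uniform is visible here.
+ %x = arith.addf %uniform, %uniform : f32
+ ...
+ }
+ ```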
2973+
2974+ Example:
2975+ ```
2976+ // Execute in parallel on all threads/lanes.
2977+ gpu.warp_execute_on_lane_0 (%laneid)[32] {
2978+ // Serial code running only on thread/lane 0.
2979+ ...
2980+ }
2981+ // Execute in parallel on all threads/lanes.
2982+ ```
2983+
2984+ This may be lowered to an scf.if region as below:
2985+ ```
2986+ // Execute in parallel on all threads/lanes.
2987+ %cnd = arith.cmpi eq, %laneid, %c0 : index
2988+ scf.if %cnd {
2989+ // Serial code running only on thread/lane 0.
2990+ ...
2991+ }
2992+ // Execute in parallel on all threads/lanes.
2993+ ```
2994+
2995+ When the region has operands and/or return values:
2996+ ```
2997+ // Execute in parallel on all threads/lanes.
2998+ %0 = gpu.warp_execute_on_lane_0(%laneid)[32]
2999+ args(%v0 : vector<4xi32>) -> (vector<1xf32>) {
3000+ ^bb0(%arg0 : vector<128xi32>) :
3001+ // Serial code running only on thread/lane 0.
3002+ ...
3003+ gpu.yield %1 : vector<32xf32>
3004+ }
3005+ // Execute in parallel on all threads/lanes.
3006+ ```
3007+
3008+ Values at the region boundary would go through memory:
3009+ ```
3010+ // Execute in parallel on all threads/lanes.
3011+ ...
3012+ // Store the data from each thread into memory and synchronize.
3013+ %tmp0 = memref.alloc() : memref<128xi32>
3014+ %tmp1 = memref.alloc() : memref<32xf32>
3015+ %cnd = arith.cmpi eq, %laneid, %c0 : index
+ // Each lane writes its 4 elements at offset 4 * laneid.
+ %off = arith.muli %laneid, %c4 : index
3016+ vector.store %v0, %tmp0[%off] : memref<128xi32>, vector<4xi32>
3017+ some_synchronization_primitive
3018+ scf.if %cnd {
3019+ // Serialized code running only on thread 0.
3020+ // Load the data from all the threads into a register on thread 0. This
3021+ // allows thread 0 to access data from all the threads.
3022+ %arg0 = vector.load %tmp0[%c0] : memref<128xi32>, vector<128xi32>
3023+ ...
3024+ // Store the data from thread 0 into memory.
3025+ vector.store %1, %tmp1[%c0] : memref<32xf32>, vector<32xf32>
3026+ }
3027+ // Synchronize and load the data in a block cyclic way so that the
3028+ // vector is distributed on all threads.
3029+ some_synchronization_primitive
3030+ %0 = vector.load %tmp1[%laneid] : memref<32xf32>, vector<1xf32>
3031+ // Execute in parallel on all threads/lanes.
3032+ ```
3033+
3034+ }];
3035+
3036+ let hasVerifier = 1;
3037+ let hasCustomAssemblyFormat = 1;
3038+ let arguments = (ins Index:$laneid, I64Attr:$warp_size,
3039+ Variadic<AnyType>:$args);
3040+ let results = (outs Variadic<AnyType>:$results);
3041+ let regions = (region SizedRegion<1>:$warpRegion);
3042+
3043+ let skipDefaultBuilders = 1;
3044+ let builders = [
3045+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
3046+ "int64_t":$warpSize)>,
3047+ // `blockArgTypes` are different from `args` types, as they represent all
3048+ // the `args` instances visible to lane 0. Therefore we need to explicitly
3049+ // pass the types.
3050+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$laneid,
3051+ "int64_t":$warpSize, "ValueRange":$args,
3052+ "TypeRange":$blockArgTypes)>
3053+ ];
3054+
3055+ let extraClassDeclaration = [{
3056+ bool isDefinedOutsideOfRegion(Value value) {
3057+ return !getRegion().isAncestor(value.getParentRegion());
3058+ }
3059+ }];
3060+ }
3061+
29243062#endif // GPU_OPS