@@ -64,10 +64,16 @@ static void foldConstantBounds(
 
 static void applyBounds(FunctionOpInterface funcOp,
                         ArrayRef<std::optional<int64_t>> workgroupSizes,
-                        ArrayRef<std::optional<int64_t>> workgroupCounts) {
+                        ArrayRef<std::optional<int64_t>> workgroupCounts,
+                        std::optional<uint64_t> subgroupSize) {
   Builder b(funcOp->getContext());
   funcOp->walk([&](Operation *op) {
     TypeSwitch<Operation *>(op)
+        .Case([&](gpu::LaneIdOp laneIdOp) {
+          if (subgroupSize) {
+            laneIdOp.setUpperBoundAttr(b.getIndexAttr(*subgroupSize));
+          }
+        })
         .Case([&](gpu::ThreadIdOp tidOp) {
           std::optional<int64_t> bound =
               workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())];
@@ -132,6 +138,8 @@ struct PropagateDispatchSizeBoundsPass final
     std::optional<SmallVector<int64_t>> staticWorkgroupSize =
         getWorkgroupSize(funcOp);
 
+    std::optional<uint64_t> subgroupSize = getGPUSubgroupSize(funcOp);
+
     // Late in codegen, we've reconciled the workgroup size onto the export op.
     if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
             getEntryPoint(funcOp)) {
@@ -141,6 +149,11 @@ struct PropagateDispatchSizeBoundsPass final
             llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
                                 [](IntegerAttr a) { return a.getInt(); });
       }
+
+      if (std::optional<uint64_t> exportSubgroupSize =
+              exportOp->getSubgroupSizeAsUInt()) {
+        subgroupSize = exportSubgroupSize;
+      }
     }
 
     if (staticWorkgroupSize) {
@@ -162,7 +175,7 @@ struct PropagateDispatchSizeBoundsPass final
     }
 
     foldConstantBounds(funcOp, staticWorkgroupSize, staticWorkgroupCounts);
-    applyBounds(funcOp, workgroupSizes, workgroupCounts);
+    applyBounds(funcOp, workgroupSizes, workgroupCounts, subgroupSize);
   }
 };
 } // namespace
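For context, a minimal sketch of the intended effect on the IR, assuming a dispatch whose export op reports a subgroup size of 64 (the value here is purely illustrative): the new `gpu::LaneIdOp` case attaches the upstream `upper_bound` attribute to `gpu.lane_id`, mirroring how the neighboring case bounds `gpu.thread_id` from the workgroup sizes.

```mlir
// Before the pass:
%lane = gpu.lane_id

// After the pass, assuming the HAL export op carries subgroup_size = 64:
%lane = gpu.lane_id upper_bound 64
```

The attached bound gives later range-based analyses a known limit on the lane ID, so arithmetic and comparisons against the subgroup size can be tightened or folded.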