@@ -129,32 +129,28 @@ void PreprocessMetadataBase::visit(Module *M) {
129129 // of ExecutionMode instructions.
130130
131131 // !{void (i32 addrspace(1)*)* @kernel, i32 17, i32 X, i32 Y, i32 Z}
132- if (MDNode *WGSize = Kernel.getMetadata (kSPIR2MD ::WGSize)) {
133- assert (WGSize->getNumOperands () >= 1 && WGSize->getNumOperands () <= 3 &&
134- " reqd_work_group_size does not have between 1 and 3 operands." );
135- SmallVector<unsigned , 3 > DecodedVals = decodeMDNode (WGSize);
136- EM.addOp ()
137- .add (&Kernel)
138- .add (spv::ExecutionModeLocalSize)
139- .add (DecodedVals[0 ])
140- .add (DecodedVals.size () >= 2 ? DecodedVals[1 ] : 1 )
141- .add (DecodedVals.size () == 3 ? DecodedVals[2 ] : 1 )
142- .done ();
143- }
144-
145132 // !{void (i32 addrspace(1)*)* @kernel, i32 18, i32 X, i32 Y, i32 Z}
146- if (MDNode *WGSizeHint = Kernel.getMetadata (kSPIR2MD ::WGSizeHint)) {
147- assert (WGSizeHint->getNumOperands () >= 1 &&
148- WGSizeHint->getNumOperands () <= 3 &&
149- " work_group_size_hint does not have between 1 and 3 operands." );
150- SmallVector<unsigned , 3 > DecodedVals = decodeMDNode (WGSizeHint);
151- EM.addOp ()
152- .add (&Kernel)
153- .add (spv::ExecutionModeLocalSizeHint)
154- .add (DecodedVals[0 ])
155- .add (DecodedVals.size () >= 2 ? DecodedVals[1 ] : 1 )
156- .add (DecodedVals.size () == 3 ? DecodedVals[2 ] : 1 )
157- .done ();
133+ // !{void (i32 addrspace(1)*)* @kernel, i32 max_work_group_size, i32 X,
134+ // i32 Y, i32 Z}
135+ std::pair<unsigned , const char *> WGSizeMDs[3 ] = {
136+ {spv::ExecutionModeLocalSize, kSPIR2MD ::WGSize},
137+ {spv::ExecutionModeLocalSizeHint, kSPIR2MD ::WGSizeHint},
138+ {spv::ExecutionModeMaxWorkgroupSizeINTEL, kSPIR2MD ::MaxWGSize},
139+ };
140+
141+ for (auto &[ExMode, MDName] : WGSizeMDs) {
142+ if (MDNode *WGMD = Kernel.getMetadata (MDName)) {
143+ assert (WGMD->getNumOperands () >= 1 && WGMD->getNumOperands () <= 3 &&
144+ " work-group metadata does not have between 1 and 3 operands." );
145+ SmallVector<unsigned , 3 > DecodedVals = decodeMDNode (WGMD);
146+ EM.addOp ()
147+ .add (&Kernel)
148+ .add (ExMode)
149+ .add (DecodedVals[0 ])
150+ .add (DecodedVals.size () >= 2 ? DecodedVals[1 ] : 1 )
151+ .add (DecodedVals.size () == 3 ? DecodedVals[2 ] : 1 )
152+ .done ();
153+ }
158154 }
159155
160156 // !{void (i32 addrspace(1)*)* @kernel, i32 30, i32 hint}
@@ -184,23 +180,6 @@ void PreprocessMetadataBase::visit(Module *M) {
184180 .done ();
185181 }
186182
187- // !{void (i32 addrspace(1)*)* @kernel, i32 max_work_group_size, i32 X,
188- // i32 Y, i32 Z}
189- if (MDNode *MaxWorkgroupSizeINTEL =
190- Kernel.getMetadata (kSPIR2MD ::MaxWGSize)) {
191- assert (MaxWorkgroupSizeINTEL->getNumOperands () == 3 &&
192- " max_work_group_size does not have 3 operands." );
193- SmallVector<unsigned , 3 > DecodedVals =
194- decodeMDNode (MaxWorkgroupSizeINTEL);
195- EM.addOp ()
196- .add (&Kernel)
197- .add (spv::ExecutionModeMaxWorkgroupSizeINTEL)
198- .add (DecodedVals[0 ])
199- .add (DecodedVals[1 ])
200- .add (DecodedVals[2 ])
201- .done ();
202- }
203-
204183 // !{void (i32 addrspace(1)*)* @kernel, i32 no_global_work_offset}
205184 if (Kernel.getMetadata (kSPIR2MD ::NoGlobalOffset)) {
206185 EM.addOp ().add (&Kernel).add (spv::ExecutionModeNoGlobalOffsetINTEL).done ();
0 commit comments