@@ -288,37 +288,6 @@ void AffixTuningParameters::affixTuningParametersImpl(
288288 }
289289}
290290
291- static RockAccelTuningParamAttrInterface
292- deriveGemm1TuningParams (OpBuilder &builder, RockGemmGemmWrapperInterface op,
293- AttnPerfConfigAttr attnPerfConfig) {
294- auto gemm0TuningParams =
295- cast<RockAccelTuningParamAttrInterface>(op.getGemm0Params ().value ());
296- int64_t gemm1KPack = gemm0TuningParams.getKpack ();
297- if (auto gemm0XdlDerivedParams =
298- dyn_cast<MfmaGemmParamsAttr>(op.getGemm0Params ().value ())) {
299- return MfmaGemmParamsAttr::get (
300- builder.getContext (), gemm0TuningParams.getMPerBlock () / gemm1KPack,
301- attnPerfConfig.getMPerBlockG1 (), gemm0XdlDerivedParams.getNPerBlock (),
302- gemm0TuningParams.getKpack (),
303- gemm0TuningParams.getMPerWave () * (attnPerfConfig.getMPerBlockG1 () /
304- gemm0TuningParams.getMPerBlock ()),
305- gemm0XdlDerivedParams.getNPerWave (),
306- gemm0XdlDerivedParams.getMnPerXdl (), attnPerfConfig.getSplitKFactor (),
307- gemm0XdlDerivedParams.getScheduleVersion (),
308- gemm0XdlDerivedParams.getOutputSwizzle (),
309- gemm0XdlDerivedParams.getForceUnroll ());
310- }
311- return WmmaGemmParamsAttr::get (
312- builder.getContext (), gemm0TuningParams.getMPerBlock () / gemm1KPack,
313- attnPerfConfig.getMPerBlockG1 (), attnPerfConfig.getNPerBlockG0 (),
314- gemm0TuningParams.getKpack (),
315- gemm0TuningParams.getMPerWave () *
316- (attnPerfConfig.getMPerBlockG1 () / gemm0TuningParams.getMPerBlock ()),
317- gemm0TuningParams.getNPerWave (), gemm0TuningParams.getMnPerXdl (),
318- attnPerfConfig.getSplitKFactor (), gemm0TuningParams.getScheduleVersion (),
319- gemm0TuningParams.getOutputSwizzle (), gemm0TuningParams.getForceUnroll ());
320- }
321-
322291void AffixTuningParameters::affixTuningParametersImpl (
323292 RockGemmGemmWrapperInterface op) {
324293 OpBuilder builder (op.getContext ());
@@ -364,55 +333,22 @@ void AffixTuningParameters::affixTuningParametersImpl(
364333 return signalPassFailure ();
365334 }
366335
367- GemmFeatures features = rock::getFeatures (op);
368- RockAccelTuningParamAttrInterface accelParams0;
369- if (bitEnumContainsAny (features, GemmFeatures::mfma)) {
370- accelParams0 = MfmaGemmParamsAttr::get (
371- builder.getContext (), attnPerfConfig.getKpackPerBlock (),
372- attnPerfConfig.getMPerBlockG0 (), attnPerfConfig.getNPerBlockG0 (),
373- attnPerfConfig.getKpack (), attnPerfConfig.getMPerWave (),
374- attnPerfConfig.getNPerWave (), attnPerfConfig.getMnPerXdl (), 1 ,
375- attnPerfConfig.getScheduleVersion (), attnPerfConfig.getOutputSwizzle (),
376- attnPerfConfig.getForceUnroll ());
377- } else {
378- accelParams0 = WmmaGemmParamsAttr::get (
379- builder.getContext (), attnPerfConfig.getKpackPerBlock (),
380- attnPerfConfig.getMPerBlockG0 (), attnPerfConfig.getNPerBlockG0 (),
381- attnPerfConfig.getKpack (), attnPerfConfig.getMPerWave (),
382- attnPerfConfig.getNPerWave (), attnPerfConfig.getMnPerXdl (), 1 ,
383- attnPerfConfig.getScheduleVersion (), attnPerfConfig.getOutputSwizzle (),
384- attnPerfConfig.getForceUnroll ());
385- }
386- op.setGemm0ParamsAttr (accelParams0);
387- if (attnPerfConfig.getMPerBlockG0 () > attnPerfConfig.getMPerBlockG1 ()) {
388- op.emitError (
389- " The MPerBlockG0 should be larger or equal to getMPerBlockG1." );
336+ auto accelParams = getAttentionTuningParams (builder, op, attnPerfConfig);
337+ if (failed (accelParams)) {
338+ op.emitError (" The provided perf config is not valid" );
390339 return signalPassFailure ();
391340 }
392- RockAccelTuningParamAttrInterface accelParams1 =
393- deriveGemm1TuningParams (builder, op, attnPerfConfig);
341+ RockAccelTuningParamAttrInterface accelParams0, accelParams1;
342+ accelParams0 = accelParams->first ;
343+ accelParams1 = accelParams->second ;
344+ LLVM_DEBUG (llvm::dbgs () << " accelParams0=" << accelParams0 << " \n " );
345+ LLVM_DEBUG (llvm::dbgs () << " accelParams1=" << accelParams1 << " \n " );
346+ op.setGemm0ParamsAttr (accelParams0);
394347 op.setGemm1ParamsAttr (accelParams1);
395348 int64_t waveSize = rock::lookupArchInfo (rock::getArchValue (op)).waveSize ;
396349 int64_t blockSize = waveSize * accelParams0.getNPerBlock () *
397350 accelParams0.getMPerBlock () /
398351 (accelParams0.getMPerWave () * accelParams0.getNPerWave ());
399- auto populateParamsAccelPtr = PopulateParamsAccel::select (features);
400- LLVM_DEBUG (llvm::dbgs () << " accelParams0=" << accelParams0 << " \n " );
401- LLVM_DEBUG (llvm::dbgs () << " accelParams1=" << accelParams1 << " \n " );
402- LogicalResult isValidBlockwiseGemm0 =
403- populateParamsAccelPtr->isValidBlockwiseGemm (
404- accelParams0, cast<MemRefType>(op.getAType ()).getElementType (),
405- cast<MemRefType>(op.getBType ()).getElementType (),
406- rock::getArchValue (op));
407- LogicalResult isValidBlockwiseGemm1 =
408- populateParamsAccelPtr->isValidBlockwiseGemm (
409- accelParams1, cast<MemRefType>(op.getCType ()).getElementType (),
410- cast<MemRefType>(op.getCType ()).getElementType (),
411- rock::getArchValue (op));
412- if (isValidBlockwiseGemm0.failed () || isValidBlockwiseGemm1.failed ()) {
413- op.emitError (" The provided perf config is not valid" );
414- return signalPassFailure ();
415- }
416352
417353 IntegerAttr blockSizeAttr = builder.getI32IntegerAttr (blockSize);
418354 func::FuncOp funcOp = getOperation ();
0 commit comments