@@ -101,24 +101,30 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
// Tail of the InitParamsNonAccel field list that is visible in this hunk.
// The change appends two int64_t fields, gemmScheduleVersion and
// outputSwizzle, after splitKFactor. The constructors' initializer lists name
// these fields in this same order.
101101 int64_t gemmNPerThread;
102102 uint32_t blockSize;
103103 int64_t splitKFactor;
104+ int64_t gemmScheduleVersion;
105+ int64_t outputSwizzle;
104106
// Member-wise constexpr constructor. mPerBlock/nPerBlock/kPerBlock are
// forwarded to the InitParams base; every other argument is stored in the
// field of the matching name (bSize -> blockSize, scheduleVersion ->
// gemmScheduleVersion). The change appends scheduleVersion and outputSwizzle
// after splitKFactor.
// NOTE(review): the two new parameters have no default values, so any call
// site still passing seven arguments has to be updated with this change.
105107 constexpr InitParamsNonAccel (uint32_t bSize, int64_t mPerBlock ,
106108 int64_t nPerBlock, int64_t kPerBlock ,
107109 int64_t mPerThread , int64_t nPerThread,
108- int64_t splitKFactor)
110+ int64_t splitKFactor, int64_t scheduleVersion,
111+ int64_t outputSwizzle)
109112 : InitParams{mPerBlock , nPerBlock, kPerBlock }, gemmMPerThread(mPerThread ),
110113 gemmNPerThread (nPerThread), blockSize(bSize),
111- splitKFactor (splitKFactor) {}
114+ splitKFactor (splitKFactor), gemmScheduleVersion(scheduleVersion),
115+ outputSwizzle (outputSwizzle) {}
112116
// Default constructor: delegates to the member-wise constructor with zero for
// the block size and all tile sizes and 1 for splitKFactor. After the change
// it also passes 1 for scheduleVersion and 2 for outputSwizzle.
// NOTE(review): what schedule version 1 and swizzle value 2 stand for is not
// visible in this hunk -- confirm against the attribute definition.
113117 constexpr InitParamsNonAccel ()
114- : InitParamsNonAccel(0U , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL ) {}
118+ : InitParamsNonAccel(0U , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL , 1LL , 2LL ) {}
115119
// Builds the parameter set from a GeneralGemmParamsAttr by reading each value
// through the attribute's getter. The change additionally reads
// getScheduleVersion() and getOutputSwizzle() into the two new fields.
// NOTE(review): this constructor is neither constexpr nor explicit, so a
// GeneralGemmParamsAttr converts implicitly to InitParamsNonAccel.
116120 InitParamsNonAccel (GeneralGemmParamsAttr attr)
117121 : InitParams{attr.getMPerBlock (), attr.getNPerBlock (),
118122 attr.getKPerBlock ()},
119123 gemmMPerThread (attr.getMPerThread()),
120124 gemmNPerThread (attr.getNPerThread()), blockSize(attr.getBlockSize()),
121- splitKFactor (attr.getSplitKFactor()){};
125+ splitKFactor (attr.getSplitKFactor()),
126+ gemmScheduleVersion (attr.getScheduleVersion()),
127+ outputSwizzle (attr.getOutputSwizzle()){};
122128
// Always returns 1; this struct stores no kPack field of its own.
// NOTE(review): the member function is not const-qualified.
123129 int64_t getKPack () { return 1 ; }
124130
@@ -130,32 +136,41 @@ struct InitParamsNonAccel : InitParams, Serializable<InitParamsNonAccel> {
130136 f (self.gemmKPerBlock );
131137 f (self.gemmMPerThread );
132138 f (self.gemmNPerThread );
133- if (self.version ! = Version::V1 )
139+ if (self.version > = Version::V2 )
134140 f (self.splitKFactor );
141+ if (self.version >= Version::V3) {
142+ f (self.gemmScheduleVersion );
143+ f (self.outputSwizzle );
144+ }
135145 }
136146};
137147
138148struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
// Member-wise constexpr constructor. mPerBlock/nPerBlock/kPerBlock are
// forwarded to the InitParams base; the remaining arguments are stored in the
// corresponding gemm*/splitKFactor/outputSwizzle fields. The change inserts
// scheduleVersion and outputSwizzle between splitKFactor and the two bool
// flags.
// NOTE(review): the new parameters sit in the middle of the list and have no
// defaults, so every positional call site must be updated with this change.
139149 constexpr InitParamsAccel (int64_t mPerBlock , int64_t nPerBlock,
140150 int64_t kPerBlock , int64_t mPerWave ,
141151 int64_t nPerWaveOrMnPerXdl, int64_t kPack ,
142- int64_t splitKFactor, bool aThreadCopyMoreGemmK,
152+ int64_t splitKFactor, int64_t scheduleVersion,
153+ int64_t outputSwizzle, bool aThreadCopyMoreGemmK,
143154 bool bThreadCopyMoreGemmKPack)
144155 : InitParams{mPerBlock , nPerBlock, kPerBlock }, gemmMPerWave(mPerWave ),
145156 gemmNPerWaveOrMnPerXdl (nPerWaveOrMnPerXdl), gemmKPack(kPack ),
146- splitKFactor (splitKFactor),
157+ splitKFactor (splitKFactor), gemmScheduleVersion(scheduleVersion),
158+ outputSwizzle (outputSwizzle),
147159 gemmAThreadCopyMoreGemmK (aThreadCopyMoreGemmK),
148160 gemmBThreadCopyMoreGemmKPack (bThreadCopyMoreGemmKPack) {}
149161
// Default constructor: delegates to the member-wise constructor with zero for
// all tile sizes and kPack, 1 for splitKFactor, and false for both bool flags.
// After the change it also passes 1 for scheduleVersion and 2 for
// outputSwizzle, the same defaults InitParamsNonAccel() uses.
150162 constexpr InitParamsAccel ()
151- : InitParamsAccel(0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL , false , false ) {}
163+ : InitParamsAccel(0LL , 0LL , 0LL , 0LL , 0LL , 0LL , 1LL , 1LL , 2LL , false ,
164+ false ) {}
152165
// Builds the parameter set from an XdlopsGemmParamsAttr. The base's K-per-block
// slot is filled from getKpackPerBlock(), gemmNPerWaveOrMnPerXdl from
// getMnPerXdl(), and gemmAThreadCopyMoreGemmK from getForceUnroll();
// gemmBThreadCopyMoreGemmKPack is always set to false. The change additionally
// reads getScheduleVersion() and getOutputSwizzle().
// NOTE(review): the constructor is non-explicit, so the attribute converts
// implicitly to InitParamsAccel.
153166 InitParamsAccel (XdlopsGemmParamsAttr attr)
154167 : InitParams{attr.getMPerBlock (), attr.getNPerBlock (),
155168 attr.getKpackPerBlock ()},
156169 gemmMPerWave (attr.getMPerWave()),
157170 gemmNPerWaveOrMnPerXdl (attr.getMnPerXdl()), gemmKPack(attr.getKpack()),
158171 splitKFactor (attr.getSplitKFactor()),
172+ gemmScheduleVersion (attr.getScheduleVersion()),
173+ outputSwizzle (attr.getOutputSwizzle()),
159174 gemmAThreadCopyMoreGemmK (attr.getForceUnroll()),
160175 gemmBThreadCopyMoreGemmKPack (false ){};
161176
@@ -165,6 +180,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
165180 gemmMPerWave (attr.getMPerWave()),
166181 gemmNPerWaveOrMnPerXdl (attr.getNPerWave()), gemmKPack(attr.getKpack()),
167182 splitKFactor (attr.getSplitKFactor()),
183+ gemmScheduleVersion (attr.getScheduleVersion()),
184+ outputSwizzle (attr.getOutputSwizzle()),
168185 gemmAThreadCopyMoreGemmK (attr.getForceUnroll()),
169186 gemmBThreadCopyMoreGemmKPack (false ){};
170187
@@ -174,6 +191,8 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
// Tail of the InitParamsAccel field list that is visible in this hunk. The
// change adds the int64_t fields gemmScheduleVersion and outputSwizzle between
// splitKFactor and the two bool flags, matching the order used by the
// constructors' initializer lists.
174191 int64_t gemmNPerWaveOrMnPerXdl;
175192 int64_t gemmKPack;
176193 int64_t splitKFactor;
194+ int64_t gemmScheduleVersion;
195+ int64_t outputSwizzle;
177196 bool gemmAThreadCopyMoreGemmK;
178197 bool gemmBThreadCopyMoreGemmKPack;
179198
@@ -185,9 +204,13 @@ struct InitParamsAccel : InitParams, Serializable<InitParamsAccel> {
185204 f (self.gemmMPerWave );
186205 f (self.gemmNPerWaveOrMnPerXdl );
187206 f (self.gemmKPack );
188- if (self.version ! = Version::V1 ) {
207+ if (self.version > = Version::V2 ) {
189208 f (self.splitKFactor );
190209 }
210+ if (self.version >= Version::V3) {
211+ f (self.gemmScheduleVersion );
212+ f (self.outputSwizzle );
213+ }
191214 f (self.gemmAThreadCopyMoreGemmK );
192215 f (self.gemmBThreadCopyMoreGemmKPack );
193216 }
@@ -206,14 +229,15 @@ class BasePopulateParams {
206229private:
207230 struct InitParamData {
// One parameter set plus the two values operator< orders by. The change only
// renames original_pos/padding_amount to originalPos/paddingAmount; the types
// (size_t and int64_t) are unchanged.
208231 InitParamType paramSet;
209- size_t original_pos ;
210- int64_t padding_amount ;
232+ size_t originalPos ;
233+ int64_t paddingAmount ;
211234
// Ordering: an entry with a smaller paddingAmount compares less; when the
// padding amounts are equal, the entry with the smaller originalPos compares
// less. The change renames the fields and turns the `else if` that followed a
// return into a separate `if`; the comparison result is the same.
212235 bool operator <(const InitParamData &rhs) const {
213- if (this ->padding_amount < rhs.padding_amount ) {
236+ if (this ->paddingAmount < rhs.paddingAmount ) {
214237 return true ;
215- } else if (this ->padding_amount == rhs.padding_amount ) {
216- return (this ->original_pos < rhs.original_pos );
238+ }
239+ if (this ->paddingAmount == rhs.paddingAmount ) {
240+ return (this ->originalPos < rhs.originalPos );
217241 }
218242 return false ;
219243 }
@@ -227,10 +251,10 @@ class BasePopulateParams {
// Visible part of a loop body: copies initParams[pos] into a fresh
// InitParamData, records pos and the result of
// calculatePaddingAmount(paramSet, gemmSize), and appends the entry to res.
227251 InitParamType paramSet = initParams[pos];
228252 InitParamData paramData;
229253 paramData.paramSet = paramSet;
230- paramData.original_pos = pos;
231- paramData.padding_amount = calculatePaddingAmount (paramSet, gemmSize);
232- assert (paramData.original_pos >= 0 );
233- assert (paramData.padding_amount >= 0 );
254+ paramData.originalPos = pos;
255+ paramData.paddingAmount = calculatePaddingAmount (paramSet, gemmSize);
// NOTE(review): originalPos is declared size_t in InitParamData, so the
// `>= 0` assertion below can never fail; only the paddingAmount (int64_t)
// assertion checks anything. Consider dropping the first assert.
256+ assert (paramData.originalPos >= 0 );
257+ assert (paramData.paddingAmount >= 0 );
234258 res.push_back (paramData);
235259 }
236260 return res;
0 commit comments