@@ -125,19 +125,13 @@ int32_t mini_jit::TensorOperation::findMatch(const std::span<const TensorConfig:
125125}
126126
127127bool mini_jit::TensorOperation::isValidPrimConfig (const std::span<const TensorConfig::dim_t > &dim,
128- const std::span<const TensorConfig::exec_t > &exec,
129- const std::span<const int64_t > &strides_in0, const std::span<const int64_t > &strides_out)
128+ const std::span<const TensorConfig::exec_t > &exec)
130129{
131130 int32_t indexM = findMatch (dim, exec, TensorConfig::dim_t ::m, TensorConfig::exec_t ::prim);
132131 int32_t indexN = findMatch (dim, exec, TensorConfig::dim_t ::n, TensorConfig::exec_t ::prim);
133132 if (indexM == -1 || indexN == -1 )
134133 {
135- std::cerr << " 1: Could not find a matching index: indexM:" << indexM << " , indexN:" << indexN << std::endl;
136- return false ;
137- }
138-
139- if (!(isExpectedStride (1 , indexM, strides_in0) && isExpectedStride (1 , indexM, strides_out)))
140- {
134+ std::cerr << " isValidPrimConfig 1: Could not find a matching index: indexM:" << indexM << " , indexN:" << indexN << std::endl;
141135 return false ;
142136 }
143137
@@ -146,16 +140,48 @@ bool mini_jit::TensorOperation::isValidPrimConfig(const std::span<const TensorCo
146140 indexN = findMatch (dim, exec, TensorConfig::dim_t ::n, TensorConfig::exec_t ::prim, indexN + 1 );
147141 if (indexM != -1 || indexN != -1 )
148142 {
149- std::cerr << " 2: Could not find a matching index: indexM:" << indexM << " , indexN" << indexN << std::endl;
143+ std::cerr << " isValidPrimConfig 2: Could not find a matching index: indexM:" << indexM << " , indexN" << indexN << std::endl;
150144 return false ;
151145 }
152146
153147 return true ;
154148}
155149
150+ bool mini_jit::TensorOperation::isValidPrimStrides (const std::span<const TensorConfig::dim_t > &dim,
151+ const std::span<const TensorConfig::exec_t > &exec,
152+ const std::span<const int64_t > &strides_in0, const std::span<const int64_t > &strides_out,
153+ const TensorConfig::prim_t main_prim)
154+ {
155+ int32_t indexM = findMatch (dim, exec, TensorConfig::dim_t ::m, TensorConfig::exec_t ::prim);
156+ int32_t indexN = findMatch (dim, exec, TensorConfig::dim_t ::n, TensorConfig::exec_t ::prim);
157+ if (indexM == -1 || indexN == -1 )
158+ {
159+ std::cerr << " isValidStride: Could not find a matching index: indexM:" << indexM << " , indexN:" << indexN << std::endl;
160+ return false ;
161+ }
162+
163+ // stride of m = 1 and stride of n = 1 i.e. no transpose
164+ if (isExpectedStride (1 , indexM, strides_in0) && isExpectedStride (1 , indexM, strides_out))
165+ {
166+ return true ;
167+ }
168+
169+ // Check transpose in unary op
170+ if (isUnary (main_prim) && isExpectedStride (1 , indexM, strides_in0) && isExpectedStride (1 , indexN, strides_out))
171+ {
172+ isTranspose = true ;
173+ return true ;
174+ }
175+
176+ std::cerr << " isValidStride: Could not find a valid stride: in0: m-stride: " << strides_in0[indexM]
177+ << " , n-stride: " << strides_in0[indexN] << " ; out: m-stride: " << strides_out[indexM] << " , n-stride: " << strides_out[indexN]
178+ << std::endl;
179+ return false ;
180+ }
181+
156182bool mini_jit::TensorOperation::isValidKDim (const std::span<const TensorConfig::dim_t > &dim,
157183 const std::span<const TensorConfig::exec_t > &exec, const std::span<const int64_t > &strides_in1,
158- TensorConfig::prim_t prim)
184+ const TensorConfig::prim_t prim)
159185{
160186 if (isBrgemm (prim))
161187 {
@@ -249,23 +275,77 @@ bool mini_jit::TensorOperation::isValidStride(const std::span<const TensorConfig
249275 switch (strideType)
250276 {
251277 case stride_t ::in0:
252- if (*iDim == TensorConfig:: dim_t ::n && *iStride != 0 )
278+ switch (*iDim)
253279 {
254- return false ;
280+ case TensorConfig::dim_t ::c:
281+ case TensorConfig::dim_t ::m:
282+ case TensorConfig::dim_t ::k:
283+ if (*iStride == 0 )
284+ {
285+ return false ;
286+ }
287+ break ;
288+
289+ case TensorConfig::dim_t ::n:
290+ if (*iStride != 0 )
291+ {
292+ return false ;
293+ }
294+ break ;
295+
296+ default :
297+ release_assert (false , " Found unhandled dimension type." );
298+ break ;
255299 }
256300 break ;
257301
258302 case stride_t ::in1:
259- if (*iDim == TensorConfig:: dim_t ::m && *iStride != 0 )
303+ switch (*iDim)
260304 {
261- return false ;
305+ case TensorConfig::dim_t ::c:
306+ case TensorConfig::dim_t ::n:
307+ case TensorConfig::dim_t ::k:
308+ if (*iStride == 0 )
309+ {
310+ return false ;
311+ }
312+ break ;
313+
314+ case TensorConfig::dim_t ::m:
315+ if (*iStride != 0 )
316+ {
317+ return false ;
318+ }
319+ break ;
320+
321+ default :
322+ release_assert (false , " Found unhandled dimension type." );
323+ break ;
262324 }
263325 break ;
264326
265327 case stride_t ::out:
266- if (*iDim == TensorConfig:: dim_t ::k && *iStride != 0 )
328+ switch (*iDim)
267329 {
268- return false ;
330+ case TensorConfig::dim_t ::c:
331+ case TensorConfig::dim_t ::n:
332+ case TensorConfig::dim_t ::m:
333+ if (*iStride == 0 )
334+ {
335+ return false ;
336+ }
337+ break ;
338+
339+ case TensorConfig::dim_t ::k:
340+ if (*iStride != 0 )
341+ {
342+ return false ;
343+ }
344+ break ;
345+
346+ default :
347+ release_assert (false , " Found unhandled dimension type." );
348+ break ;
269349 }
270350 break ;
271351
@@ -279,7 +359,7 @@ bool mini_jit::TensorOperation::isValidStride(const std::span<const TensorConfig
279359}
280360
281361mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary (Unary &unary, TensorConfig::prim_t prim,
282- const std::span<const int64_t > &dim_sizes)
362+ const std::span<const int64_t > &dim_sizes, bool isTranspose )
283363{
284364 release_assert (indexPrimM != -1 , " Expected a match for the m primitive dimension" );
285365 release_assert (indexPrimN != -1 , " Expected a match for the n primitive dimension" );
@@ -303,7 +383,8 @@ mini_jit::Unary::error_t mini_jit::TensorOperation::generateUnary(Unary &unary,
303383 release_assert (false , " Found a invalid type for the unary first touch." );
304384 break ;
305385 }
306- return unary.generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], 0 , Unary::dtype_t ::fp32, type);
386+
387+ return unary.generate (dim_sizes[indexPrimM], dim_sizes[indexPrimN], isTranspose, Unary::dtype_t ::fp32, type);
307388}
308389
309390mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup (const TensorConfig &config)
@@ -322,7 +403,10 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
322403 std::span<const TensorConfig::dim_t > dim_types, std::span<const TensorConfig::exec_t > exec_types, std::span<const int64_t > dim_sizes,
323404 std::span<const int64_t > strides_in0, std::span<const int64_t > strides_in1, std::span<const int64_t > strides_out)
324405{
406+ // Reset to defaults
325407 hasSetupError = false ;
408+ isParallel = false ;
409+ isTranspose = false ;
326410 indexPrimBatch = -1 ;
327411 indexPrimK = -1 ;
328412 indexPrimM = -1 ;
@@ -367,6 +451,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
367451 if (kDimExecType != -1 )
368452 {
369453 hasSetupError = true ;
454+ std::cerr << " Error: Found k dimension tagged as shared, but can not execute k dimension as shared." << std::endl;
370455 return error_t ::err_k_dimension_must_not_be_shared;
371456 }
372457 }
@@ -387,14 +472,21 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
387472 return error_t ::err_invalid_execution_order;
388473 }
389474
390- if (!isValidPrimConfig (dim_types, exec_types, strides_in0, strides_out ))
475+ if (!isValidPrimConfig (dim_types, exec_types))
391476 {
392477 hasSetupError = true ;
393478 std::cerr << " Error: Invalid primitive configuration detected. Expected one primitive for m and one primitive for n to exist"
394479 << std::endl;
395480 return error_t ::err_invalid_primitive_configuration;
396481 }
397482
483+ if (!isValidPrimStrides (dim_types, exec_types, strides_in0, strides_out, prim_main))
484+ {
485+ hasSetupError = true ;
486+ std::cerr << " Error: Invalid strides for the primitive m dimension (or n dimension if transpose)." << std::endl;
487+ return error_t ::err_invalid_strides;
488+ }
489+
398490 if (!isValidKDim (dim_types, exec_types, strides_in1, prim_main))
399491 {
400492 hasSetupError = true ;
@@ -412,6 +504,13 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
412504 std::cerr << " Error: Invalid stride configuration detected for unary. Expected k-dimension to have a stride of zero." << std::endl;
413505 return error_t ::err_invalid_strides;
414506 }
507+
508+ if (prim_last_touch != TensorConfig::prim_t ::none || prim_last_touch != TensorConfig::prim_t ::none)
509+ {
510+ hasSetupError = true ;
511+ std::cerr << " Error: A main 'Unary' primitive can not have first touch and last touch primitives." << std::endl;
512+ return error_t ::err_invalid_main_configuration;
513+ }
415514 }
416515 else if (isBrgemm (prim_main))
417516 {
@@ -448,7 +547,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
448547 first_touch.emplace <Unary>();
449548 TensorOperation::prim_first = prim_first_touch;
450549
451- Unary::error_t error = generateUnary (std::get<Unary>(first_touch), prim_first_touch, dim_sizes);
550+ Unary::error_t error = generateUnary (std::get<Unary>(first_touch), prim_first_touch, dim_sizes, false );
452551
453552 if (error != Unary::error_t ::success)
454553 {
@@ -517,7 +616,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
517616 main_kernel.emplace <Unary>();
518617 TensorOperation::prim_main = prim_main;
519618
520- Unary::error_t error = generateUnary (std::get<Unary>(main_kernel), prim_main, dim_sizes);
619+ Unary::error_t error = generateUnary (std::get<Unary>(main_kernel), prim_main, dim_sizes, isTranspose );
521620
522621 if (error != Unary::error_t ::success)
523622 {
@@ -541,7 +640,7 @@ mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup_no_optimizat
541640 last_touch.emplace <Unary>();
542641 TensorOperation::prim_last = prim_last_touch;
543642
544- Unary::error_t error = generateUnary (std::get<Unary>(last_touch), prim_last_touch, dim_sizes);
643+ Unary::error_t error = generateUnary (std::get<Unary>(last_touch), prim_last_touch, dim_sizes, false );
545644
546645 if (error != Unary::error_t ::success)
547646 {
@@ -663,7 +762,8 @@ void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const
663762 if (std::holds_alternative<Unary>(main_kernel))
664763 {
665764 Unary::kernel_t kernel = std::get<Unary>(main_kernel).get_kernel ();
666- kernel (ptr_in0, ptr_out, strides_in0[indexPrimN], strides_out[indexPrimN]);
765+ int32_t indexLeadingDimension = isTranspose ? indexPrimM : indexPrimN;
766+ kernel (ptr_in0, ptr_out, strides_in0[indexPrimN], strides_out[indexLeadingDimension]);
667767 }
668768 else if (std::holds_alternative<Brgemm>(main_kernel))
669769 {
0 commit comments