@@ -370,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
370370                                    result.operands )))
371371    return  failure ();
372372
373-   result.addTypes (fnTy.getResult ( 0 ));
373+   result.addTypes (fnTy.getResults ( ));
374374  result.addAttributes (attrs);
375375
376376  return  success ();
@@ -532,6 +532,24 @@ void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
532532  printWithEnumHandling (parser, *this );
533533}
534534
535+ ParseResult CastFromBlockScaledOp::parse (OpAsmParser &parser,
536+                                          OperationState &result) {
537+   return  parseWithEnumHandling<tosa::BlockSize>(parser, result);
538+ }
539+ 
540+ void  CastFromBlockScaledOp::print (OpAsmPrinter &parser) {
541+   printWithEnumHandling (parser, *this );
542+ }
543+ 
544+ ParseResult CastToBlockScaledOp::parse (OpAsmParser &parser,
545+                                        OperationState &result) {
546+   return  parseWithEnumHandling<tosa::BlockSize>(parser, result);
547+ }
548+ 
549+ void  CastToBlockScaledOp::print (OpAsmPrinter &parser) {
550+   printWithEnumHandling (parser, *this );
551+ }
552+ 
535553// ===----------------------------------------------------------------------===//
536554//  Tosa utilities.
537555// ===----------------------------------------------------------------------===//
@@ -3944,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
39443962  return  success ();
39453963}
39463964
3965+ LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents (
3966+     MLIRContext *context, ::std::optional<Location> location,
3967+     CastFromBlockScaledOp::Adaptor adaptor,
3968+     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3969+   const  ShapeAdaptor inputShape (adaptor.getInputData ().getType ());
3970+   inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
3971+   return  success ();
3972+ }
3973+ 
3974+ LogicalResult CastFromBlockScaledOp::verify () {
3975+   const  Type inputDataType = getInputData ().getType ();
3976+   const  Type outputDataType = getResult ().getType ();
3977+   if  (failed (verifyCompatibleShape (inputDataType, outputDataType)))
3978+     return  emitOpError () << " require compatible shapes for input_data (" 
3979+                          << inputDataType << " ) and " 
3980+                          << " output_data (" " )" 
3981+ 
3982+   const  ShapeAdaptor inputDataShape = ShapeAdaptor (inputDataType);
3983+ 
3984+   if  (inputDataShape.hasRank ()) {
3985+     const  unsigned  int  blockSize =
3986+         BlockSizeAttr::getBlockSizeValue (getBlockSize ());
3987+     const  int64_t  inputDataLastDim =
3988+         inputDataShape.getDimSize (inputDataShape.getRank () - 1 );
3989+     if  (inputDataLastDim % blockSize != 0 )
3990+       return  emitOpError () << " expect last dimension of input_data (" 
3991+                            << inputDataLastDim
3992+                            << " ) to be divisible by block_size (" 
3993+                            << " )" 
3994+ 
3995+     const  Type inputScaleType = getInputScale ().getType ();
3996+     const  ShapeAdaptor inputScaleShape = ShapeAdaptor (inputScaleType);
3997+ 
3998+     if  (inputScaleShape.hasRank ()) {
3999+       SmallVector<int64_t > inputDataDims, inputScaleDims;
4000+       inputDataShape.getDims (inputDataDims);
4001+       inputScaleShape.getDims (inputScaleDims);
4002+ 
4003+       if  (inputDataDims.size () != inputScaleDims.size () ||
4004+           failed (verifyCompatibleShape (
4005+               ArrayRef<int64_t >(inputDataDims).drop_back (1 ),
4006+               ArrayRef<int64_t >(inputScaleDims).drop_back (1 ))))
4007+         return  emitOpError () << " require compatible shapes for input_data (" 
4008+                              << inputDataType << " ) and " 
4009+                              << " input_scale (" 
4010+                              << " ) except for the last dimension" 
4011+ 
4012+       const  SmallVector<int64_t , 2 > dimsToCheck{inputDataLastDim / blockSize,
4013+                                                 inputScaleDims.back ()};
4014+       if  (ShapedType::isStatic (inputDataLastDim) &&
4015+           failed (verifyCompatibleDims (dimsToCheck)))
4016+         return  emitOpError ()
4017+                << " expect last dimension of input_scale (" 
4018+                << inputScaleDims.back ()
4019+                << " ) to be equal to last dimension of input_data / block_size (" 
4020+                << inputDataDims.back () / blockSize << " )" 
4021+     }
4022+   }
4023+ 
4024+   return  success ();
4025+ }
4026+ 
4027+ LogicalResult CastToBlockScaledOp::inferReturnTypeComponents (
4028+     MLIRContext *context, ::std::optional<Location> location,
4029+     CastToBlockScaledOp::Adaptor adaptor,
4030+     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4031+   const  ShapeAdaptor inputShape (adaptor.getInputData ().getType ());
4032+   inferredReturnShapes.push_back (ShapedTypeComponents (inputShape));
4033+   if  (!inputShape.hasRank ())
4034+     return  success ();
4035+ 
4036+   //  Calculate output_scale shape if ranked input provided
4037+   SmallVector<int64_t > outputScaleShape;
4038+   inputShape.getDims (outputScaleShape);
4039+   const  int64_t  lastDimLoc = inputShape.getRank () - 1 ;
4040+   const  int64_t  lastDimSize = inputShape.getDimSize (lastDimLoc);
4041+   if  (ShapedType::isStatic (lastDimSize)) {
4042+     const  unsigned  int  blockSize =
4043+         BlockSizeAttr::getBlockSizeValue (adaptor.getBlockSize ());
4044+     outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
4045+   }
4046+   inferredReturnShapes.push_back (ShapedTypeComponents (outputScaleShape));
4047+   return  success ();
4048+ }
4049+ 
4050+ LogicalResult CastToBlockScaledOp::verify () {
4051+   const  Type inputDataType = getInputData ().getType ();
4052+   const  Type outputDataType = getResult (0 ).getType ();
4053+   if  (failed (verifyCompatibleShape (inputDataType, outputDataType)))
4054+     return  emitOpError () << " require compatible shapes for input_data (" 
4055+                          << inputDataType << " ) and " 
4056+                          << " output_data (" " )" 
4057+ 
4058+   const  unsigned  int  blockSize =
4059+       BlockSizeAttr::getBlockSizeValue (getBlockSize ());
4060+   const  ShapeAdaptor inputDataShape = ShapeAdaptor (inputDataType);
4061+   if  (inputDataShape.hasRank ()) {
4062+     const  int64_t  inputDataLastDim =
4063+         inputDataShape.getDimSize (inputDataShape.getRank () - 1 );
4064+     if  (ShapedType::isStatic (inputDataLastDim) &&
4065+         inputDataLastDim % blockSize != 0 )
4066+       return  emitOpError () << " expect last dimension of input_data (" 
4067+                            << inputDataLastDim
4068+                            << " ) to be divisible by block_size (" 
4069+                            << " )" 
4070+   }
4071+ 
4072+   const  ShapeAdaptor outputDataShape = ShapeAdaptor (outputDataType);
4073+   const  Type outputScaleType = getResult (1 ).getType ();
4074+   const  ShapeAdaptor outputScaleShape = ShapeAdaptor (outputScaleType);
4075+   if  (outputDataShape.hasRank () && outputScaleShape.hasRank ()) {
4076+     SmallVector<int64_t > outputDataDims, outputScaleDims;
4077+     outputDataShape.getDims (outputDataDims);
4078+     outputScaleShape.getDims (outputScaleDims);
4079+ 
4080+     if  (outputDataDims.size () != outputScaleDims.size () ||
4081+         failed (verifyCompatibleShape (
4082+             ArrayRef<int64_t >(outputDataDims).drop_back (1 ),
4083+             ArrayRef<int64_t >(outputScaleDims).drop_back (1 ))))
4084+       return  emitOpError () << " require compatible shapes for output_data (" 
4085+                            << outputDataType << " ) and " 
4086+                            << " output_scale (" 
4087+                            << " ) except for the last dimension" 
4088+ 
4089+     const  int64_t  outputDataLastDim = outputDataDims.back ();
4090+     const  SmallVector<int64_t , 2 > dimsToCheck{outputDataLastDim / blockSize,
4091+                                               outputScaleDims.back ()};
4092+     if  (ShapedType::isStatic (outputDataLastDim) &&
4093+         failed (verifyCompatibleDims (dimsToCheck)))
4094+       return  emitOpError ()
4095+              << " expect last dimension of output_scale (" 
4096+              << outputScaleDims.back ()
4097+              << " ) to be equal to last dimension of output_data / block_size (" 
4098+              << outputDataDims.back () / blockSize << " )" 
4099+   }
4100+ 
4101+   return  success ();
4102+ }
4103+ 
39474104LogicalResult IfOp::inferReturnTypeComponents (
39484105    MLIRContext *context, ::std::optional<Location> location,
39494106    IfOp::Adaptor adaptor,
0 commit comments