@@ -787,3 +787,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
787787 %cst = tosa.const_shape {values = dense <1 > : tensor <4 xindex >} : () -> !tosa.shape <4 >
788788 return %cst : !tosa.shape <4 >
789789}
790+
791+ // F8 support tests
792+
793+ // -----
794+ // CHECK-LABEL: argmax_f8E5M2
795+ func.func @test_argmax_f8E5M2 (%arg0: tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 > {
796+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E5 M2 >) -> tensor <12 x16 xi32 >
797+ return %0 : tensor <12 x16 xi32 >
798+ }
799+
800+ // -----
801+ // CHECK-LABEL: avg_pool2d_f8E5M2
802+ func.func @test_avg_pool2d_f8E5M2 (%arg0: tensor <1 x7 x7 x9 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 > {
803+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
804+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
805+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x7 x7 x9 xf8 E5 M2 >
806+ return %0 : tensor <1 x7 x7 x9 xf8 E5 M2 >
807+ }
808+
809+ // -----
810+ // CHECK-LABEL: conv2d_f8E5M2
811+ func.func @test_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <8 x1 x1 x4 xf8 E5 M2 >, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
812+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
813+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E5 M2 >}> : () -> tensor <1 xf8 E5 M2 >
814+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <8 x1 x1 x4 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
815+ return %0 : tensor <1 x4 x4 x8 xf16 >
816+ }
817+
818+ // -----
819+ // CHECK-LABEL: conv3d_f8E5M2
820+ func.func @test_conv3d_f8E5M2 (%arg0: tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, %arg1: tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 > {
821+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E5 M2 >, tensor <34 x1 x1 x1 x17 xf8 E5 M2 >, tensor <34 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x8 x21 x34 xf16 >
822+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
823+ }
824+
825+ // -----
826+ // CHECK-LABEL: depthwise_conv2d_f8E5M2
827+ func.func @test_depthwise_conv2d_f8E5M2 (%arg0: tensor <1 x4 x4 x4 xf8 E5 M2 >, %arg1: tensor <1 x1 x4 x2 xf8 E5 M2 >, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 > {
828+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E5 M2 >, tensor <1 x1 x4 x2 xf8 E5 M2 >, tensor <8 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x4 x4 x8 xf16 >
829+ return %0 : tensor <1 x4 x4 x8 xf16 >
830+ }
831+
832+ // -----
833+ // CHECK-LABEL: test_matmul_f8E5M2
834+ func.func @test_matmul_f8E5M2 (%arg0: tensor <1 x14 x19 xf8 E5 M2 >, %arg1: tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 > {
835+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E5 M2 >, tensor <1 x19 x28 xf8 E5 M2 >) -> tensor <1 x14 x28 xf16 >
836+ return %0 : tensor <1 x14 x28 xf16 >
837+ }
838+
839+ // -----
840+ // CHECK-LABEL: max_pool2d_f8E5M2
841+ func.func @test_max_pool2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 > {
842+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >) -> tensor <1 x32 x32 x8 xf8 E5 M2 >
843+ return %0 : tensor <1 x32 x32 x8 xf8 E5 M2 >
844+ }
845+
846+ // -----
847+
848+ // CHECK-LABEL: transpose_conv2d_f8E5M2
849+ func.func @test_transpose_conv2d_f8E5M2 (%arg0: tensor <1 x32 x32 x8 xf8 E5 M2 >, %arg1: tensor <16 x1 x1 x8 xf8 E5 M2 >, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E5 M2 >, %arg4: tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 > {
850+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E5 M2 >, tensor <16 x1 x1 x8 xf8 E5 M2 >, tensor <16 xf16 >, tensor <1 xf8 E5 M2 >, tensor <1 xf8 E5 M2 >) -> tensor <1 x32 x32 x16 xf16 >
851+ return %0 : tensor <1 x32 x32 x16 xf16 >
852+ }
853+
854+ // -----
855+ // CHECK-LABEL: const_f8E5M2
856+ func.func @test_const_f8E5M2 (%arg0 : index ) -> tensor <4 xf8 E5 M2 > {
857+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E5 M2 >} : () -> tensor <4 xf8 E5 M2 >
858+ return %0 : tensor <4 xf8 E5 M2 >
859+ }
860+
861+ // -----
862+ // CHECK-LABEL: cast_f8E5M2
863+ func.func @test_cast_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 > {
864+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf16 >
865+ return %0 : tensor <13 x21 x3 xf16 >
866+ }
867+
868+ // -----
869+ // CHECK-LABEL: concat_f8E5M2
870+ func.func @test_concat_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 > {
871+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <26 x21 x3 xf8 E5 M2 >
872+ return %0 : tensor <26 x21 x3 xf8 E5 M2 >
873+ }
874+
875+ // -----
876+ // CHECK-LABEL: pad_f8E5M2
877+ func.func @test_pad_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
878+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
879+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E5 M2 > } : () -> tensor <1 xf8 E5 M2 >
880+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <6 >, tensor <1 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
881+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
882+ }
883+
884+ // -----
885+ // CHECK-LABEL: reshape_f8E5M2
886+ func.func @test_reshape_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <1 x819 xf8 E5 M2 > {
887+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
888+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <2 >) -> tensor <1 x819 xf8 E5 M2 >
889+ return %0 : tensor <1 x819 xf8 E5 M2 >
890+ }
891+
892+ // -----
893+ // CHECK-LABEL: reverse_f8E5M2
894+ func.func @test_reverse_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
895+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
896+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
897+ }
898+
899+ // -----
900+ // CHECK-LABEL: slice_f8E5M2
901+ func.func @test_slice_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <4 x11 x1 xf8 E5 M2 > {
902+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
903+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
904+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E5 M2 >
905+ return %2 : tensor <4 x11 x1 xf8 E5 M2 >
906+ }
907+
908+ // -----
909+ // CHECK-LABEL: tile_f8E5M2
910+ func.func @test_tile_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <39 x21 x6 xf8 E5 M2 > {
911+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
912+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E5 M2 >, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E5 M2 >
913+ return %0 : tensor <39 x21 x6 xf8 E5 M2 >
914+ }
915+
916+ // -----
917+ func.func @test_transpose_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 > {
918+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E5 M2 >) -> tensor <3 x13 x21 xf8 E5 M2 >
919+ return %1 : tensor <3 x13 x21 xf8 E5 M2 >
920+ }
921+
922+ // -----
923+ // CHECK-LABEL: gather_f8E5M2
924+ func.func @test_gather_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 > {
925+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E5 M2 >
926+ return %0 : tensor <13 x26 x3 xf8 E5 M2 >
927+ }
928+
929+ // -----
930+ // CHECK-LABEL: scatter_f8E5M2
931+ func.func @test_scatter_f8E5M2 (%arg0: tensor <13 x21 x3 xf8 E5 M2 >, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 > {
932+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E5 M2 >, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E5 M2 >) -> tensor <13 x21 x3 xf8 E5 M2 >
933+ return %0 : tensor <13 x21 x3 xf8 E5 M2 >
934+ }
935+
936+ // -----
937+ // CHECK-LABEL: argmax_f8E4M3FN
938+ func.func @test_argmax_f8E4M3FN (%arg0: tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 > {
939+ %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor <12 x8 x16 xf8 E4 M3 FN>) -> tensor <12 x16 xi32 >
940+ return %0 : tensor <12 x16 xi32 >
941+ }
942+
943+ // -----
944+ // CHECK-LABEL: avg_pool2d_f8E4M3FN
945+ func.func @test_avg_pool2d_f8E4M3FN (%arg0: tensor <1 x7 x7 x9 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN> {
946+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
947+ %output_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
948+ %0 = tosa.avg_pool2d %arg0 , %input_zp , %output_zp {acc_type = f16 , kernel = array<i64 : 2 , 2 >, pad = array<i64 : 0 , 1 , 0 , 1 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x7 x7 x9 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x7 x7 x9 xf8 E4 M3 FN>
949+ return %0 : tensor <1 x7 x7 x9 xf8 E4 M3 FN>
950+ }
951+
952+ // -----
953+ // CHECK-LABEL: conv2d_f8E4M3FN
954+ func.func @test_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <8 x1 x1 x4 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >) -> tensor <1 x4 x4 x8 xf16 > {
955+ %input_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
956+ %weight_zp = " tosa.const" () <{values = dense <0.0 > : tensor <1 xf8 E4 M3 FN>}> : () -> tensor <1 xf8 E4 M3 FN>
957+ %0 = tosa.conv2d %arg0 , %arg1 , %arg2 , %input_zp , %weight_zp {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >, local_bound = true } : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <8 x1 x1 x4 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
958+ return %0 : tensor <1 x4 x4 x8 xf16 >
959+ }
960+
961+ // -----
962+ // CHECK-LABEL: conv3d_f8E4M3FN
963+ func.func @test_conv3d_f8E4M3FN (%arg0: tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, %arg1: tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, %arg2: tensor <34 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 > {
964+ %0 = tosa.conv3d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 , 1 >} : (tensor <1 x4 x8 x21 x17 xf8 E4 M3 FN>, tensor <34 x1 x1 x1 x17 xf8 E4 M3 FN>, tensor <34 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x8 x21 x34 xf16 >
965+ return %0 : tensor <1 x4 x8 x21 x34 xf16 >
966+ }
967+
968+ // -----
969+ // CHECK-LABEL: depthwise_conv2d_f8E4M3FN
970+ func.func @test_depthwise_conv2d_f8E4M3FN (%arg0: tensor <1 x4 x4 x4 xf8 E4 M3 FN>, %arg1: tensor <1 x1 x4 x2 xf8 E4 M3 FN>, %arg2: tensor <8 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 > {
971+ %0 = tosa.depthwise_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , dilation = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x4 x4 x4 xf8 E4 M3 FN>, tensor <1 x1 x4 x2 xf8 E4 M3 FN>, tensor <8 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x4 x4 x8 xf16 >
972+ return %0 : tensor <1 x4 x4 x8 xf16 >
973+ }
974+
975+ // -----
976+ // CHECK-LABEL: matmul_f8E4M3FN
977+ func.func @test_matmul_f8E4M3FN (%arg0: tensor <1 x14 x19 xf8 E4 M3 FN>, %arg1: tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 > {
978+ %0 = tosa.matmul %arg0 , %arg1 : (tensor <1 x14 x19 xf8 E4 M3 FN>, tensor <1 x19 x28 xf8 E4 M3 FN>) -> tensor <1 x14 x28 xf16 >
979+ return %0 : tensor <1 x14 x28 xf16 >
980+ }
981+
982+ // -----
983+ // CHECK-LABEL: max_pool2d_f8E4M3FN
984+ func.func @test_max_pool2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN> {
985+ %0 = tosa.max_pool2d %arg0 {kernel = array<i64 : 1 , 1 >, pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x8 xf8 E4 M3 FN>
986+ return %0 : tensor <1 x32 x32 x8 xf8 E4 M3 FN>
987+ }
988+
989+ // -----
990+ // CHECK-LABEL: transpose_conv2d_f8E4M3FN
991+ func.func @test_transpose_conv2d_f8E4M3FN (%arg0: tensor <1 x32 x32 x8 xf8 E4 M3 FN>, %arg1: tensor <16 x1 x1 x8 xf8 E4 M3 FN>, %arg2: tensor <16 xf16 >, %arg3: tensor <1 xf8 E4 M3 FN>, %arg4: tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 > {
992+ %0 = tosa.transpose_conv2d %arg0 , %arg1 , %arg2 , %arg3 , %arg4 {acc_type = f16 , out_pad = array<i64 : 0 , 0 , 0 , 0 >, stride = array<i64 : 1 , 1 >} : (tensor <1 x32 x32 x8 xf8 E4 M3 FN>, tensor <16 x1 x1 x8 xf8 E4 M3 FN>, tensor <16 xf16 >, tensor <1 xf8 E4 M3 FN>, tensor <1 xf8 E4 M3 FN>) -> tensor <1 x32 x32 x16 xf16 >
993+ return %0 : tensor <1 x32 x32 x16 xf16 >
994+ }
995+
996+ // -----
997+ // CHECK-LABEL: const_f8E4M3FN
998+ func.func @test_const_f8E4M3FN (%arg0 : index ) -> tensor <4 xf8 E4 M3 FN> {
999+ %0 = " tosa.const" () {values = dense <[3.0 , -0.0 , -1.0 , 2.0 ]> : tensor <4 xf8 E4 M3 FN>} : () -> tensor <4 xf8 E4 M3 FN>
1000+ return %0 : tensor <4 xf8 E4 M3 FN>
1001+ }
1002+
1003+ // -----
1004+ // CHECK-LABEL: cast_f8E4M3FN
1005+ func.func @test_cast_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 > {
1006+ %0 = tosa.cast %arg0 : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf16 >
1007+ return %0 : tensor <13 x21 x3 xf16 >
1008+ }
1009+
1010+ // -----
1011+ // CHECK-LABEL: concat_f8E4M3FN
1012+ func.func @test_concat_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN> {
1013+ %0 = tosa.concat %arg0 , %arg1 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <26 x21 x3 xf8 E4 M3 FN>
1014+ return %0 : tensor <26 x21 x3 xf8 E4 M3 FN>
1015+ }
1016+
1017+ // -----
1018+ // CHECK-LABEL: pad_f8E4M3FN
1019+ func.func @test_pad_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1020+ %padding = tosa.const_shape {values = dense <0 > : tensor <6 xindex >} : () -> !tosa.shape <6 >
1021+ %cst = " tosa.const" () { values = dense <-0.0 > : tensor <1 xf8 E4 M3 FN> } : () -> tensor <1 xf8 E4 M3 FN>
1022+ %0 = tosa.pad %arg0 , %padding , %cst : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <6 >, tensor <1 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1023+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1024+ }
1025+
1026+ // -----
1027+ // CHECK-LABEL: reshape_f8E4M3FN
1028+ func.func @test_reshape_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <1 x819 xf8 E4 M3 FN> {
1029+ %1 = tosa.const_shape {values = dense <[1 , 819 ]> : tensor <2 xindex >} : () -> !tosa.shape <2 >
1030+ %0 = tosa.reshape %arg0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <2 >) -> tensor <1 x819 xf8 E4 M3 FN>
1031+ return %0 : tensor <1 x819 xf8 E4 M3 FN>
1032+ }
1033+
1034+ // -----
1035+ // CHECK-LABEL: reverse_f8E4M3FN
1036+ func.func @test_reverse_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1037+ %0 = tosa.reverse %arg0 {axis = 0 : i32 } : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1038+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1039+ }
1040+
1041+ // -----
1042+ // CHECK-LABEL: slice_f8E4M3FN
1043+ func.func @test_slice_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <4 x11 x1 xf8 E4 M3 FN> {
1044+ %0 = tosa.const_shape {values = dense <[4 , 11 , 1 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1045+ %1 = tosa.const_shape {values = dense <[6 , 8 , 0 ]> : tensor <3 xindex >} : () -> !tosa.shape <3 >
1046+ %2 = tosa.slice %arg0 , %0 , %1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >, !tosa.shape <3 >) -> tensor <4 x11 x1 xf8 E4 M3 FN>
1047+ return %2 : tensor <4 x11 x1 xf8 E4 M3 FN>
1048+ }
1049+
1050+ // -----
1051+ // CHECK-LABEL: tile_f8E4M3FN
1052+ func.func @test_tile_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <39 x21 x6 xf8 E4 M3 FN> {
1053+ %cst = tosa.const_shape { values = dense <[3 , 1 , 2 ]> : tensor <3 xindex > } : () -> !tosa.shape <3 >
1054+ %0 = tosa.tile %arg0 , %cst: (tensor <13 x21 x3 xf8 E4 M3 FN>, !tosa.shape <3 >) -> tensor <39 x21 x6 xf8 E4 M3 FN>
1055+ return %0 : tensor <39 x21 x6 xf8 E4 M3 FN>
1056+ }
1057+
1058+ // -----
1059+ // CHECK-LABEL: transpose_f8E4M3FN
1060+ func.func @test_transpose_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN> {
1061+ %1 = tosa.transpose %arg0 {perms = array<i32 : 2 , 0 , 1 >} : (tensor <13 x21 x3 xf8 E4 M3 FN>) -> tensor <3 x13 x21 xf8 E4 M3 FN>
1062+ return %1 : tensor <3 x13 x21 xf8 E4 M3 FN>
1063+ }
1064+
1065+ // -----
1066+ // CHECK-LABEL: gather_f8E4M3FN
1067+ func.func @test_gather_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN> {
1068+ %0 = tosa.gather %arg0 , %arg1 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >) -> tensor <13 x26 x3 xf8 E4 M3 FN>
1069+ return %0 : tensor <13 x26 x3 xf8 E4 M3 FN>
1070+ }
1071+
1072+ // -----
1073+ // CHECK-LABEL: scatter_f8E4M3FN
1074+ func.func @test_scatter_f8E4M3FN (%arg0: tensor <13 x21 x3 xf8 E4 M3 FN>, %arg1: tensor <13 x26 xi32 >, %arg2: tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN> {
1075+ %0 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <13 x21 x3 xf8 E4 M3 FN>, tensor <13 x26 xi32 >, tensor <13 x26 x3 xf8 E4 M3 FN>) -> tensor <13 x21 x3 xf8 E4 M3 FN>
1076+ return %0 : tensor <13 x21 x3 xf8 E4 M3 FN>
1077+ }
0 commit comments