@@ -783,3 +783,291 @@ func.func @test_const_shape() -> !tosa.shape<4> {
783783 %cst = tosa.const_shape {values = dense <1 > : tensor <4 xindex >} : () -> !tosa.shape <4 >
784784 return %cst : !tosa.shape <4 >
785785}
786+
787+ // F8 support tests
788+
789+ // -----
// CHECK-LABEL: argmax_f8E5M2
// tosa.argmax over axis 1 of an f8E5M2 input; the result is an i32 tensor
// with that axis removed.
func.func @test_argmax_f8E5M2(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32> {
  %indices = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi32>
  return %indices : tensor<12x16xi32>
}
795+
796+ // -----
// CHECK-LABEL: avg_pool2d_f8E5M2
// tosa.avg_pool2d on f8E5M2 data with a 2x2 kernel, f16 accumulator type and
// constant 0.0 input/output zero points.
func.func @test_avg_pool2d_f8E5M2(%arg0: tensor<1x7x7x9xf8E5M2>) -> tensor<1x7x7x9xf8E5M2> {
  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
  %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
  %pooled = tosa.avg_pool2d %arg0, %in_zp, %out_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E5M2>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x7x7x9xf8E5M2>
  return %pooled : tensor<1x7x7x9xf8E5M2>
}
804+
805+ // -----
// CHECK-LABEL: conv2d_f8E5M2
// tosa.conv2d with f8E5M2 input and weights, an f16 bias, f16 accumulator
// type and an f16 result. Both zero points are constant 0.0 and the
// local_bound attribute is set.
func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1x4xf8E5M2>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
  %w_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2>
  %conv = tosa.conv2d %arg0, %arg1, %arg2, %in_zp, %w_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E5M2>, tensor<8x1x1x4xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
  return %conv : tensor<1x4x4x8xf16>
}
813+
814+ // -----
// CHECK-LABEL: conv3d_f8E5M2
// tosa.conv3d with f8E5M2 input, weights and zero points (the zero points are
// function arguments), an f16 bias, f16 accumulator type and an f16 result.
func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> {
  %conv = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16>
  return %conv : tensor<1x4x8x21x34xf16>
}
820+
821+ // -----
// CHECK-LABEL: depthwise_conv2d_f8E5M2
// tosa.depthwise_conv2d with f8E5M2 input, weights and zero points (the zero
// points are function arguments), an f16 bias, f16 accumulator type and an
// f16 result.
func.func @test_depthwise_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<1x1x4x2xf8E5M2>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16> {
  %conv = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E5M2>, tensor<1x1x4x2xf8E5M2>, tensor<8xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x4x8xf16>
  return %conv : tensor<1x4x4x8xf16>
}
827+
828+ // -----
// CHECK-LABEL: matmul_f8E5M2
// tosa.matmul of two f8E5M2 operands producing an f16 result.
// The label drops the "test_" prefix so it follows the same naming scheme as
// the other FP8 tests in this file (it still matches @test_matmul_f8E5M2).
func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16> {
  %product = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E5M2>) -> tensor<1x14x28xf16>
  return %product : tensor<1x14x28xf16>
}
834+
835+ // -----
// CHECK-LABEL: max_pool2d_f8E5M2
// tosa.max_pool2d on f8E5M2 data with a 1x1 kernel, unit stride and no
// padding; input and result types are identical.
func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {
  %pooled = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2>
  return %pooled : tensor<1x32x32x8xf8E5M2>
}
841+
842+ // -----
843+
// CHECK-LABEL: transpose_conv2d_f8E5M2
// tosa.transpose_conv2d with f8E5M2 input, weights and zero points (the zero
// points are function arguments), an f16 bias, f16 accumulator type and an
// f16 result.
func.func @test_transpose_conv2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>, %arg1: tensor<16x1x1x8xf8E5M2>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16> {
  %conv = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E5M2>, tensor<16x1x1x8xf8E5M2>, tensor<16xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x32x32x16xf16>
  return %conv : tensor<1x32x32x16xf16>
}
849+
850+ // -----
// CHECK-LABEL: const_f8E5M2
// tosa.const holding four f8E5M2 values, including a negative zero.
// The index argument is not used by the body.
func.func @test_const_f8E5M2(%arg0 : index) -> tensor<4xf8E5M2> {
  %values = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E5M2>} : () -> tensor<4xf8E5M2>
  return %values : tensor<4xf8E5M2>
}
856+
857+ // -----
// CHECK-LABEL: cast_f8E5M2
// tosa.cast from f8E5M2 to f16 with the shape left unchanged.
func.func @test_cast_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16> {
  %widened = tosa.cast %arg0 : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf16>
  return %widened : tensor<13x21x3xf16>
}
863+
864+ // -----
// CHECK-LABEL: concat_f8E5M2
// tosa.concat of two 13x21x3 f8E5M2 tensors along axis 0, giving 26x21x3.
func.func @test_concat_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2> {
  %joined = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>, tensor<13x21x3xf8E5M2>) -> tensor<26x21x3xf8E5M2>
  return %joined : tensor<26x21x3xf8E5M2>
}
870+
871+ // -----
// CHECK-LABEL: pad_f8E5M2
// tosa.pad on f8E5M2 data using an all-zero padding shape and a -0.0 pad
// constant; the result type equals the input type.
func.func @test_pad_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
  %pad_amounts = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
  %pad_value = "tosa.const"() {values = dense<-0.0> : tensor<1xf8E5M2>} : () -> tensor<1xf8E5M2>
  %padded = tosa.pad %arg0, %pad_amounts, %pad_value : (tensor<13x21x3xf8E5M2>, !tosa.shape<6>, tensor<1xf8E5M2>) -> tensor<13x21x3xf8E5M2>
  return %padded : tensor<13x21x3xf8E5M2>
}
879+
880+ // -----
// CHECK-LABEL: reshape_f8E5M2
// tosa.reshape of a 13x21x3 f8E5M2 tensor to 1x819 using a constant shape.
func.func @test_reshape_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<1x819xf8E5M2> {
  %new_shape = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %reshaped = tosa.reshape %arg0, %new_shape : (tensor<13x21x3xf8E5M2>, !tosa.shape<2>) -> tensor<1x819xf8E5M2>
  return %reshaped : tensor<1x819xf8E5M2>
}
887+
888+ // -----
// CHECK-LABEL: reverse_f8E5M2
// tosa.reverse along axis 0 of an f8E5M2 tensor; the type is unchanged.
func.func @test_reverse_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
  %reversed = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
  return %reversed : tensor<13x21x3xf8E5M2>
}
894+
895+ // -----
// CHECK-LABEL: slice_f8E5M2
// tosa.slice on an f8E5M2 tensor. The op's shape operands are (start, size).
// Previously the [4, 11, 1] constant was passed as start and [6, 8, 0] as
// size, which contradicts the 4x11x1 result type; the operands are now named
// and ordered so that size matches the result shape and start + size stays
// within the 13x21x3 input.
func.func @test_slice_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<4x11x1xf8E5M2> {
  %start = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %size = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %sliced = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E5M2>
  return %sliced : tensor<4x11x1xf8E5M2>
}
903+
904+ // -----
// CHECK-LABEL: tile_f8E5M2
// tosa.tile of a 13x21x3 f8E5M2 tensor with multiples [3, 1, 2], giving
// 39x21x6.
func.func @test_tile_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<39x21x6xf8E5M2> {
  %multiples = tosa.const_shape {values = dense<[3, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %tiled = tosa.tile %arg0, %multiples : (tensor<13x21x3xf8E5M2>, !tosa.shape<3>) -> tensor<39x21x6xf8E5M2>
  return %tiled : tensor<39x21x6xf8E5M2>
}
911+
912+ // -----
// CHECK-LABEL: transpose_f8E5M2
// tosa.transpose of a 13x21x3 f8E5M2 tensor with perms [2, 0, 1], giving
// 3x13x21. The label above was missing, so FileCheck never verified that
// this function appears in the output.
func.func @test_transpose_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2> {
  %transposed = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E5M2>) -> tensor<3x13x21xf8E5M2>
  return %transposed : tensor<3x13x21xf8E5M2>
}
917+
918+ // -----
// CHECK-LABEL: gather_f8E5M2
// tosa.gather from a 13x21x3 f8E5M2 tensor using 13x26 i32 indices, giving a
// 13x26x3 f8E5M2 result.
func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2> {
  %gathered = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>) -> tensor<13x26x3xf8E5M2>
  return %gathered : tensor<13x26x3xf8E5M2>
}
924+
925+ // -----
// CHECK-LABEL: scatter_f8E5M2
// tosa.scatter into a 13x21x3 f8E5M2 tensor using 13x26 i32 indices and a
// 13x26x3 f8E5M2 update tensor; the result type equals the first operand's.
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
  %scattered = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
  return %scattered : tensor<13x21x3xf8E5M2>
}
931+
932+ // -----
// CHECK-LABEL: argmax_f8E4M3FN
// tosa.argmax over axis 1 of an f8E4M3FN input; the result is an i32 tensor
// with that axis removed.
func.func @test_argmax_f8E4M3FN(%arg0: tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32> {
  %indices = tosa.argmax %arg0 {axis = 1 : i32} : (tensor<12x8x16xf8E4M3FN>) -> tensor<12x16xi32>
  return %indices : tensor<12x16xi32>
}
938+
939+ // -----
// CHECK-LABEL: avg_pool2d_f8E4M3FN
// tosa.avg_pool2d on f8E4M3FN data with a 2x2 kernel, f16 accumulator type
// and constant 0.0 input/output zero points.
func.func @test_avg_pool2d_f8E4M3FN(%arg0: tensor<1x7x7x9xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN> {
  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
  %out_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
  %pooled = tosa.avg_pool2d %arg0, %in_zp, %out_zp {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf8E4M3FN>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x7x7x9xf8E4M3FN>
  return %pooled : tensor<1x7x7x9xf8E4M3FN>
}
947+
948+ // -----
// CHECK-LABEL: conv2d_f8E4M3FN
// tosa.conv2d with f8E4M3FN input and weights, an f16 bias, f16 accumulator
// type and an f16 result. Both zero points are constant 0.0 and the
// local_bound attribute is set.
func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8x1x1x4xf8E4M3FN>, %arg2: tensor<8xf16>) -> tensor<1x4x4x8xf16> {
  %in_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
  %w_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN>
  %conv = tosa.conv2d %arg0, %arg1, %arg2, %in_zp, %w_zp {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf8E4M3FN>, tensor<8x1x1x4xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
  return %conv : tensor<1x4x4x8xf16>
}
956+
957+ // -----
// CHECK-LABEL: conv3d_f8E4M3FN
// tosa.conv3d with f8E4M3FN input, weights and zero points (the zero points
// are function arguments), an f16 bias, f16 accumulator type and an f16
// result.
func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> {
  %conv = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16>
  return %conv : tensor<1x4x8x21x34xf16>
}
963+
964+ // -----
// CHECK-LABEL: depthwise_conv2d_f8E4M3FN
// tosa.depthwise_conv2d with f8E4M3FN input, weights and zero points (the
// zero points are function arguments), an f16 bias, f16 accumulator type and
// an f16 result.
func.func @test_depthwise_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<1x1x4x2xf8E4M3FN>, %arg2: tensor<8xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16> {
  %conv = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf8E4M3FN>, tensor<1x1x4x2xf8E4M3FN>, tensor<8xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x4x8xf16>
  return %conv : tensor<1x4x4x8xf16>
}
970+
971+ // -----
// CHECK-LABEL: matmul_f8E4M3FN
// tosa.matmul of two f8E4M3FN operands producing an f16 result.
func.func @test_matmul_f8E4M3FN(%arg0: tensor<1x14x19xf8E4M3FN>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> {
  %product = tosa.matmul %arg0, %arg1 : (tensor<1x14x19xf8E4M3FN>, tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16>
  return %product : tensor<1x14x28xf16>
}
977+
978+ // -----
// CHECK-LABEL: max_pool2d_f8E4M3FN
// tosa.max_pool2d on f8E4M3FN data with a 1x1 kernel, unit stride and no
// padding; input and result types are identical.
func.func @test_max_pool2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN> {
  %pooled = tosa.max_pool2d %arg0 {kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>) -> tensor<1x32x32x8xf8E4M3FN>
  return %pooled : tensor<1x32x32x8xf8E4M3FN>
}
984+
985+ // -----
// CHECK-LABEL: transpose_conv2d_f8E4M3FN
// tosa.transpose_conv2d with f8E4M3FN input, weights and zero points (the
// zero points are function arguments), an f16 bias, f16 accumulator type and
// an f16 result.
func.func @test_transpose_conv2d_f8E4M3FN(%arg0: tensor<1x32x32x8xf8E4M3FN>, %arg1: tensor<16x1x1x8xf8E4M3FN>, %arg2: tensor<16xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16> {
  %conv = tosa.transpose_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf8E4M3FN>, tensor<16x1x1x8xf8E4M3FN>, tensor<16xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x32x32x16xf16>
  return %conv : tensor<1x32x32x16xf16>
}
991+
992+ // -----
// CHECK-LABEL: const_f8E4M3FN
// tosa.const holding four f8E4M3FN values, including a negative zero.
// The index argument is not used by the body.
func.func @test_const_f8E4M3FN(%arg0 : index) -> tensor<4xf8E4M3FN> {
  %values = "tosa.const"() {values = dense<[3.0, -0.0, -1.0, 2.0]> : tensor<4xf8E4M3FN>} : () -> tensor<4xf8E4M3FN>
  return %values : tensor<4xf8E4M3FN>
}
998+
999+ // -----
// CHECK-LABEL: cast_f8E4M3FN
// tosa.cast from f8E4M3FN to f16 with the shape left unchanged.
func.func @test_cast_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16> {
  %widened = tosa.cast %arg0 : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf16>
  return %widened : tensor<13x21x3xf16>
}
1005+
1006+ // -----
// CHECK-LABEL: concat_f8E4M3FN
// tosa.concat of two 13x21x3 f8E4M3FN tensors along axis 0, giving 26x21x3.
func.func @test_concat_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN> {
  %joined = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>, tensor<13x21x3xf8E4M3FN>) -> tensor<26x21x3xf8E4M3FN>
  return %joined : tensor<26x21x3xf8E4M3FN>
}
1012+
1013+ // -----
// CHECK-LABEL: pad_f8E4M3FN
// tosa.pad on f8E4M3FN data using an all-zero padding shape and a -0.0 pad
// constant; the result type equals the input type.
func.func @test_pad_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
  %pad_amounts = tosa.const_shape {values = dense<0> : tensor<6xindex>} : () -> !tosa.shape<6>
  %pad_value = "tosa.const"() {values = dense<-0.0> : tensor<1xf8E4M3FN>} : () -> tensor<1xf8E4M3FN>
  %padded = tosa.pad %arg0, %pad_amounts, %pad_value : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<6>, tensor<1xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
  return %padded : tensor<13x21x3xf8E4M3FN>
}
1021+
1022+ // -----
// CHECK-LABEL: reshape_f8E4M3FN
// tosa.reshape of a 13x21x3 f8E4M3FN tensor to 1x819 using a constant shape.
func.func @test_reshape_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<1x819xf8E4M3FN> {
  %new_shape = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %reshaped = tosa.reshape %arg0, %new_shape : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<2>) -> tensor<1x819xf8E4M3FN>
  return %reshaped : tensor<1x819xf8E4M3FN>
}
1029+
1030+ // -----
// CHECK-LABEL: reverse_f8E4M3FN
// tosa.reverse along axis 0 of an f8E4M3FN tensor; the type is unchanged.
func.func @test_reverse_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
  %reversed = tosa.reverse %arg0 {axis = 0 : i32} : (tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
  return %reversed : tensor<13x21x3xf8E4M3FN>
}
1036+
1037+ // -----
// CHECK-LABEL: slice_f8E4M3FN
// tosa.slice on an f8E4M3FN tensor. The op's shape operands are (start,
// size). Previously the [4, 11, 1] constant was passed as start and
// [6, 8, 0] as size, which contradicts the 4x11x1 result type; the operands
// are now named and ordered so that size matches the result shape and
// start + size stays within the 13x21x3 input.
func.func @test_slice_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<4x11x1xf8E4M3FN> {
  %start = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %size = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %sliced = tosa.slice %arg0, %start, %size : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf8E4M3FN>
  return %sliced : tensor<4x11x1xf8E4M3FN>
}
1045+
1046+ // -----
// CHECK-LABEL: tile_f8E4M3FN
// tosa.tile of a 13x21x3 f8E4M3FN tensor with multiples [3, 1, 2], giving
// 39x21x6.
func.func @test_tile_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<39x21x6xf8E4M3FN> {
  %multiples = tosa.const_shape {values = dense<[3, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %tiled = tosa.tile %arg0, %multiples : (tensor<13x21x3xf8E4M3FN>, !tosa.shape<3>) -> tensor<39x21x6xf8E4M3FN>
  return %tiled : tensor<39x21x6xf8E4M3FN>
}
1053+
1054+ // -----
// CHECK-LABEL: transpose_f8E4M3FN
// tosa.transpose of a 13x21x3 f8E4M3FN tensor with perms [2, 0, 1], giving
// 3x13x21.
func.func @test_transpose_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN> {
  %transposed = tosa.transpose %arg0 {perms = array<i32: 2, 0, 1>} : (tensor<13x21x3xf8E4M3FN>) -> tensor<3x13x21xf8E4M3FN>
  return %transposed : tensor<3x13x21xf8E4M3FN>
}
1060+
1061+ // -----
// CHECK-LABEL: gather_f8E4M3FN
// tosa.gather from a 13x21x3 f8E4M3FN tensor using 13x26 i32 indices, giving
// a 13x26x3 f8E4M3FN result.
func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN> {
  %gathered = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>) -> tensor<13x26x3xf8E4M3FN>
  return %gathered : tensor<13x26x3xf8E4M3FN>
}
1067+
1068+ // -----
// CHECK-LABEL: scatter_f8E4M3FN
// tosa.scatter into a 13x21x3 f8E4M3FN tensor using 13x26 i32 indices and a
// 13x26x3 f8E4M3FN update tensor; the result type equals the first operand's.
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
  %scattered = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
  return %scattered : tensor<13x21x3xf8E4M3FN>
}
0 commit comments