diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 455d07fb4408a..23692478755c6 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -237,23 +237,24 @@ class Tosa_I32EnumAttr; def Tosa_PRO_INT : I32EnumAttrCase<"pro_int", 1>; def Tosa_PRO_FP : I32EnumAttrCase<"pro_fp", 2>; -def Tosa_NONE : I32EnumAttrCase<"none", 3>; -def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; -def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; -def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>; -def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>; -def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>; -def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>; -def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>; -def Tosa_EXT_NONE : I32EnumAttrCase<"none", 8>; +def Tosa_EXT_NONE : I32EnumAttrCase<"none", 0>; +def Tosa_EXT_INT16 : I32EnumAttrCase<"int16", 1>; +def Tosa_EXT_INT4 : I32EnumAttrCase<"int4", 2>; +def Tosa_EXT_BF16 : I32EnumAttrCase<"bf16", 3>; +def Tosa_EXT_FP8E4M3 : I32EnumAttrCase<"fp8e4m3", 4>; +def Tosa_EXT_FP8E5M2 : I32EnumAttrCase<"fp8e5m2", 5>; +def Tosa_EXT_FFT : I32EnumAttrCase<"fft", 6>; +def Tosa_EXT_VARIABLE : I32EnumAttrCase<"variable", 7>; +def Tosa_EXT_CONTROLFLOW : I32EnumAttrCase<"controlflow", 8>; def Tosa_ExtensionAttr : Tosa_I32EnumAttr<"Extension", "supported TOSA extensions", "ext", [ Tosa_EXT_INT16, Tosa_EXT_INT4, Tosa_EXT_BF16, Tosa_EXT_FP8E4M3, - Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_NONE + Tosa_EXT_FP8E5M2, Tosa_EXT_FFT, Tosa_EXT_VARIABLE, Tosa_EXT_CONTROLFLOW, Tosa_EXT_NONE ]>; def Tosa_ExtensionArrayAttr diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 8f822f930e164..454f3ee4db8af 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -2436,8 +2436,8 @@ def Tosa_IfOp : Tosa_Op<"cond_if", ); list availability = [ - Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, - Extension<[]>, + Profile<[]>, + Extension<[Tosa_EXT_CONTROLFLOW]>, ]; let regions = (region @@ -2477,8 +2477,8 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [ ); list availability = [ - Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>, - Extension<[]>, + Profile<[]>, + Extension<[Tosa_EXT_CONTROLFLOW]>, ]; let regions = (region diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h index 5de2a591e698d..064264e73c8af 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaProfileCompliance.h @@ -143,6 +143,7 @@ class TosaProfileCompliance { case Extension::fft: return {Profile::pro_fp}; case Extension::variable: + case Extension::controlflow: return {Profile::pro_fp, Profile::pro_int}; case Extension::none: return {}; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index f74a4b4c58b80..32648830bb760 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -425,7 +425,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { } else { llvm::errs() << "unknown TOSA extension name passed in: " << ext << ", supported extension are int16, int4, bf16, " - << "fp8e4m3, fp8e5m2, fft, and variable\n"; + << "fp8e4m3, fp8e5m2, fft, variable and controlflow\n"; return signalPassFailure(); } } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index e66ff4cacfd89..da8f9ef82c839 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -629,8 +629,8 @@ func.func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { // ----- // CHECK-LABEL: cond_if func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // CHECK: profiles: [ [pro_int, pro_fp] ] - // CHECK: extensions: [ [bf16] ] + // CHECK: tosa.cond_if profiles: [ ] + // CHECK: tosa.cond_if extensions: [ [controlflow] ] %0 = tosa.cond_if %arg2 -> (tensor) { %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor tosa.yield %1 : tensor @@ -645,8 +645,8 @@ func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg1: tensor) { %0 = "tosa.const"() {value = dense<0> : tensor} : () -> tensor - // CHECK: profiles: [ [pro_int, pro_fp] ] - // CHECK: extensions: [ [bf16] ] + // CHECK: profiles: [ ] + // CHECK: extensions: [ [controlflow] ] %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) { %2 = tosa.greater_equal %arg3, %arg1 : (tensor, tensor) -> tensor %3 = tosa.logical_not %2 : (tensor) -> tensor diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 1aa8547cb2fdb..c44a0d1c09215 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -4,7 +4,7 @@ // validation flow. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment" func.func @test_const() -> tensor<1xf32> { // expected-error@+1{{'tosa.const' op expected same attr/result element types}} diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 046b9d5615074..684875f231dec 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles to focus the verification of expected extension requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp,mt strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int,pro_fp strict-op-spec-alignment" // ----- func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) { @@ -36,3 +36,37 @@ func.func @test_cast_bf16_i32(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xi32 return %0 : tensor<13x21x3xi32> } +// ----- +func.func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // expected-error@+1 {{'tosa.cond_if' op illegal: requires [controlflow]}} + %0 = tosa.cond_if %arg2 -> (tensor) { + %1 = tosa.add %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } else { + %1 = tosa.sub %arg0, %arg1 : (tensor, tensor) -> tensor + tosa.yield %1 : tensor + } + return %0 : tensor +} + +// ----- +func.func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor) { + %0 = "tosa.const"() {value = dense<0> : tensor} : () -> tensor + // expected-error@+1 {{'tosa.while_loop' op illegal: requires [controlflow]}} + %1:3 = tosa.while_loop (%arg2 = %0, %arg3 = %0, %arg4 = %arg0) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) { + %2 = tosa.greater_equal %arg3, %arg1 : (tensor, tensor) -> tensor + %3 = tosa.logical_not %2 : (tensor) -> tensor + tosa.yield %3 : tensor + } do { + ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): + %2 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor + %3 = tosa.add %arg3, %2 : (tensor, tensor) -> tensor + %7 = tosa.const_shape {value = dense<[1]> : tensor<1xindex>} : () -> !tosa.shape<1> + %4 = tosa.reshape %2, %7 : (tensor, !tosa.shape<1>) -> tensor<1xi32> + %5 = tosa.add %arg4, %4 : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> + %6 = tosa.add %arg2, %2 : (tensor, tensor) -> tensor + tosa.yield %6, %3, %5 : tensor, tensor, tensor<10xi32> + } + return +} + diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 90c4551564d1e..a75a6bee8e809 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -2,7 +2,7 @@ // Enable all supported profiles and extensions to focus the verification of expected level errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp,mt extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate="profile=pro_int,pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow" func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xi32> { // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} diff --git a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir index 6dddcf329d110..8183b58272e84 100644 --- a/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_all_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment" // ----- func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () { diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index c46b2543fbed5..f7cbd114280dc 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_int extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment" // ----- func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index 479b7569f54ae..1d6d33b9a02c7 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -2,7 +2,7 @@ // Enable all supported extensions to focus the verification of expected profile requirement errors. //-------------------------------------------------------------------------------------------------- -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable strict-op-spec-alignment" +// RUN: mlir-opt %s -split-input-file -verify-diagnostics -tosa-validate="profile=pro_fp extension=int16,int4,bf16,fp8e4m3,fp8e5m2,fft,variable,controlflow strict-op-spec-alignment" // ----- func.func @test_table(%arg0 : tensor<4x5xi8>, %arg1 : tensor<513xi8>) -> () {