Select broadcast Iota to concat broadcast slice #2365
Description
```mlir
func.func @main(%arg0: tensor<1520x3056xf64>, %arg1: tensor<1520x3056xf64>) -> tensor<1520x3056xf64> {
  %c = stablehlo.constant {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} dense<-1519> : tensor<1520xi64>
  %c_0 = stablehlo.constant {enzymexla.non_negative = [#enzymexla<guaranteed GUARANTEED>]} dense<0> : tensor<1520xi64>
  %0 = stablehlo.iota dim = 0 {enzymexla.non_negative = [#enzymexla<guaranteed GUARANTEED>]} : tensor<1520xi64>
  %1 = stablehlo.add %0, %c {enzymexla.non_negative = [#enzymexla<guaranteed NOTGUARANTEED>]} : tensor<1520xi64>
  %2 = stablehlo.compare EQ, %1, %c_0 : (tensor<1520xi64>, tensor<1520xi64>) -> tensor<1520xi1>
  %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<1520xi1>) -> tensor<1520x3056xi1>
  %4 = stablehlo.select %3, %arg0, %arg1 : tensor<1520x3056xi1>, tensor<1520x3056xf64>
  return %4 : tensor<1520x3056xf64>
}
```
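In this example the predicate `iota + (-1519) == 0` is true only for row 1519, so the `select` takes row 1519 from `%arg0` and rows 0 to 1518 from `%arg1`. The sketch below is a hypothetical, pure-Python model of the analysis such a rewrite would need, not the actual pattern implementation. It assumes the predicate has the form `iota + c == k` along the broadcast dimension. It splits the rows into contiguous runs where the predicate is constant, then rebuilds the `select` as a concatenation of row slices. The names `iota_eq_segments`, `select_as_concat`, and `select_reference` are illustrative only.

```python
# Hypothetical model of: select(broadcast(compare(EQ, iota + c, k)), on_true, on_false)
# along dim 0 rewritten as a concatenation of row slices of on_true / on_false.

def iota_eq_segments(n, c, k=0):
    """Split rows [0, n) into contiguous runs where (i + c == k) is constant.

    Returns a list of (start, stop, pred) tuples with start < stop.
    """
    hit = k - c  # the only index i where i + c == k
    if not 0 <= hit < n:
        return [(0, n, False)]  # predicate is false for every row
    segments = []
    if hit > 0:
        segments.append((0, hit, False))
    segments.append((hit, hit + 1, True))
    if hit + 1 < n:
        segments.append((hit + 1, n, False))
    return segments


def select_as_concat(on_true, on_false, segments):
    """Rebuild select(pred, on_true, on_false) as a concat of row slices."""
    out = []
    for start, stop, pred in segments:
        src = on_true if pred else on_false
        out.extend(src[start:stop])  # slice rows [start, stop), append along dim 0
    return out


def select_reference(on_true, on_false, c, k=0):
    """Elementwise reference: row i comes from on_true iff i + c == k."""
    return [t if i + c == k else f
            for i, (t, f) in enumerate(zip(on_true, on_false))]
```

For the shapes in this issue, `iota_eq_segments(1520, -1519)` gives `[(0, 1519, False), (1519, 1520, True)]`. That corresponds to a `stablehlo.slice` of `%arg1` over rows `[0:1519]` and a `stablehlo.slice` of `%arg0` over rows `[1519:1520]`, joined by a `stablehlo.concatenate` along dim 0. This removes the `iota`, `compare`, `broadcast_in_dim`, and full-size `select`.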