@@ -858,7 +858,7 @@ func.func @native_layer_norm_mixed_dtypes(%input: !torch.vtensor<[1,56,56,96],bf
858
858
// CHECK-DAG: %[[C3:.*]] = torch.constant.int 3
859
859
// CHECK-DAG: %[[C4:.*]] = torch.constant.int 4
860
860
// CHECK-DAG: %[[C5:.*]] = torch.constant.int 5
861
- // CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C2 ]], %[[C4 ]], %[[C3 ]], %[[C5 ]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
861
+ // CHECK: %[[PERMLIST:.*]] = torch.prim.ListConstruct %[[C0]], %[[C1]], %[[C3 ]], %[[C5 ]], %[[C2 ]], %[[C4 ]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
862
862
// CHECK: %[[EXPAND1:.*]] = torch.prims.split_dim %[[ARG0]], %[[C2]], %[[C2]] : !torch.vtensor<[1,8,4,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,4],f32>
863
863
// CHECK: %[[EXPAND2:.*]] = torch.prims.split_dim %[[EXPAND1]], %[[C4]], %[[C2]] : !torch.vtensor<[1,8,2,2,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,8,2,2,2,2],f32>
864
864
// CHECK: %[[PERMUTE:.*]] = torch.aten.permute %[[EXPAND2]], %[[PERMLIST]] : !torch.vtensor<[1,8,2,2,2,2],f32>, !torch.list<int> -> !torch.vtensor<[1,8,2,2,2,2],f32>
0 commit comments