Skip to content

Commit 5651a87

Browse files
authored
enable fp32 cat slice
Differential Revision: D70216001 Pull Request resolved: pytorch/executorch#8759
1 parent 2d0cf64 commit 5651a87

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

backends/cadence/fusion_g3/operators/op_cat.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ Tensor& cat_out(
115115
(out.scalar_type() == ScalarType::Char) ||
116116
(out.scalar_type() == ScalarType::UInt32) ||
117117
(out.scalar_type() == ScalarType::UInt16) ||
118-
(out.scalar_type() == ScalarType::Byte)) {
118+
(out.scalar_type() == ScalarType::Byte) ||
119+
(out.scalar_type() == ScalarType::Float)) {
119120
XT_KERNEL_CHECK(
120121
ctx,
121122
out,

backends/cadence/fusion_g3/operators/op_slice_copy.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ Tensor& slice_copy_Tensor_out(
101101
(out.scalar_type() == ScalarType::Char) ||
102102
(out.scalar_type() == ScalarType::UInt32) ||
103103
(out.scalar_type() == ScalarType::UInt16) ||
104-
(out.scalar_type() == ScalarType::Byte))) {
104+
(out.scalar_type() == ScalarType::Byte) ||
105+
(out.scalar_type() == ScalarType::Float))) {
105106
XT_KERNEL_CHECK(
106107
ctx,
107108
out,
@@ -132,4 +133,4 @@ Tensor& slice_copy_Tensor_out(
132133
} // namespace native
133134
} // namespace G3
134135
} // namespace impl
135-
} // namespace cadence
136+
} // namespace cadence

backends/cadence/fusion_g3/operators/xt_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ inline int get_element_size(ScalarType dtype) {
1919
return sizeof(short);
2020
} else if ((dtype == ScalarType::Char) || (dtype == ScalarType::Byte)) {
2121
return sizeof(char);
22+
} else if (dtype == ScalarType::Float) {
23+
return sizeof(float);
2224
}
2325
return 0;
2426
}

backends/cadence/utils/facto_util.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# seed to generate identical cases every run to reproduce from bisect
2020
random_manager.seed(1729)
21+
MAX_CASES = 50
2122

2223

2324
def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> None:
@@ -46,6 +47,14 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
4647
cp.Value.Le(lambda deps, dtype, struct: 2**2),
4748
]
4849
)
50+
case "slice_copy.Tensor":
51+
tensor_constraints.extend(
52+
[
53+
cp.Rank.Le(lambda deps: 2),
54+
cp.Value.Ge(lambda deps, dtype, struct: 1),
55+
cp.Value.Le(lambda deps, dtype, struct: 2),
56+
]
57+
)
4958
case _:
5059
tensor_constraints.extend(
5160
[
@@ -124,4 +133,4 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
124133
return [
125134
(posargs, inkwargs)
126135
for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen()
127-
]
136+
][:MAX_CASES]

0 commit comments

Comments
 (0)