Skip to content

Commit a2f9587

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Custom where_Scalar op
Summary: Continued support of custom cadence ops Differential Revision: D82703256
1 parent eaaa76e commit a2f9587

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,3 +1108,15 @@ def rope(
11081108
[x0 * cos_tensor - x1 * sin_tensor, x0 * sin_tensor + x1 * cos_tensor], dim=-1
11091109
)
11101110
return rotated
1111+
1112+
1113+
@impl(m, "where_Scalar")
1114+
def where_Scalar(
1115+
condition: torch.Tensor,
1116+
if_true: float,
1117+
if_false: float,
1118+
) -> torch.Tensor:
1119+
if condition.dtype != torch.bool:
1120+
raise ValueError("condition must be a bool tensor")
1121+
1122+
return torch.where(condition, if_true, if_false)

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,3 +1256,14 @@ def test_rope(
12561256
torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
12571257
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
12581258
)
1259+
1260+
def test_where_Scalar(self) -> None:
1261+
input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8)
1262+
out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0)
1263+
self.assertTrue(
1264+
torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32))
1265+
)
1266+
with self.assertRaises(ValueError) as context:
1267+
torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
1268+
1269+
self.assertIn("condition must be a bool tensor", str(context.exception))

0 commit comments

Comments
 (0)