Skip to content

Commit 3862537

Browse files
Andrew Grebenisanfacebook-github-bot
authored andcommitted
Custom where_Scalar op (#14470)
Summary: Continued support of custom cadence ops Reviewed By: hsharma35 Differential Revision: D82703256
1 parent b991271 commit 3862537

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

backends/cadence/aot/ref_implementations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,3 +1092,15 @@ def rms_norm(
10921092
eps: float,
10931093
) -> torch.Tensor:
10941094
return W * nn.RMSNorm(list(normalized_shape), eps=eps, dtype=X.dtype)(X)
1095+
1096+
1097+
@impl(m, "where_Scalar")
1098+
def where_Scalar(
1099+
condition: torch.Tensor,
1100+
if_true: float,
1101+
if_false: float,
1102+
) -> torch.Tensor:
1103+
if condition.dtype != torch.bool:
1104+
raise ValueError("condition must be a bool tensor")
1105+
1106+
return torch.where(condition, if_true, if_false)

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,3 +1145,14 @@ def test_quantized_relu(
11451145
torch.equal(output, expected_output),
11461146
f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
11471147
)
1148+
1149+
def test_where_Scalar(self) -> None:
1150+
input_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.int8)
1151+
out = torch.ops.cadence.where_Scalar(input_tensor > 2, 1.0, 0.0)
1152+
self.assertTrue(
1153+
torch.equal(out, torch.tensor([0.0, 0.0, 1.0, 1.0], dtype=torch.float32))
1154+
)
1155+
with self.assertRaises(ValueError) as context:
1156+
torch.ops.cadence.where_Scalar(input_tensor, 1.0, 0.0)
1157+
1158+
self.assertIn("condition must be a bool tensor", str(context.exception))

0 commit comments

Comments
 (0)