
Commit 0b78412

Ensure we can call custom ops from torch cadence lib
Differential Revision: D81738196
Pull Request resolved: #14034
1 parent 57173d9 commit 0b78412

File tree

4 files changed (+67, -88 lines)

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -614,6 +614,7 @@ python_unittest(
     typing = True,
     deps = [
         ":typing_stubs",
+        "//executorch/backends/cadence/aot:ops_registrations",
         "//executorch/backends/cadence/aot:ref_implementations",
         "//caffe2:torch",
     ]
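
Note on the dependency above: op registration happens at import time, so the test target needs ops_registrations on its deps for the custom ops to be callable at all. A hypothetical sketch of the intended usage (the actual test file is not shown in this section, and the cadence namespace is assumed):

    # Hypothetical test snippet: importing ops_registrations registers the
    # cadence custom op schemas with torch, so the reference implementations
    # can be reached through the dispatcher.
    import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
    import executorch.backends.cadence.aot.ref_implementations  # noqa: F401
    import torch

    # After import, the op resolves under torch.ops (namespace assumed here):
    op = torch.ops.cadence.quantized_layer_norm.per_tensor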

backends/cadence/aot/ops_registrations.py

Lines changed: 2 additions & 2 deletions

@@ -1449,7 +1449,7 @@ def quantized_layer_norm_meta(
     input: torch.Tensor,
     X_scale: torch.Tensor,
     X_zero_point: torch.Tensor,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,

@@ -1464,7 +1464,7 @@ def quantized_layer_norm_per_tensor_meta(
     input: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
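
The list[int] fix matches the contract of torch.nn.functional.layer_norm, which takes the trailing dimension sizes as a sequence, not a single int. For example:

    import torch

    x = torch.randn(2, 4, 8)
    # layer_norm normalizes over the trailing dims listed in normalized_shape,
    # so it takes a list of ints; a bare int cannot express multi-dim shapes.
    out1 = torch.nn.functional.layer_norm(x, [8], eps=1e-5)
    out2 = torch.nn.functional.layer_norm(x, [4, 8], eps=1e-5)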

backends/cadence/aot/ref_implementations.py

Lines changed: 19 additions & 19 deletions

@@ -64,9 +64,9 @@ def quantize_per_tensor(
         f"Unsupported dtype to quantize to. Supported dtypes must be one of {supported_quant_types}"
     )

-    dequantized = torch.round(input_tensor * scale + zero_point).to(dtype)
+    quantized = torch.round(input_tensor * scale + zero_point).to(dtype)
     return torch.max(
-        torch.min(dequantized, torch.tensor(quant_max)),
+        torch.min(quantized, torch.tensor(quant_max)),
         torch.tensor(quant_min),
     )
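The rename here is purely cosmetic: the rounded value is the quantized tensor, not a dequantized one, and the max/min composition clamps it into [quant_min, quant_max]. A standalone sketch of the same pattern using torch.clamp, which is equivalent:

    import torch

    def quantize_clamp_sketch(x, scale, zero_point, quant_min, quant_max, dtype):
        # Round into the quantized domain (scale convention follows the diff above),
        # then clamp to the representable range of the target dtype.
        quantized = torch.round(x * scale + zero_point).to(dtype)
        # Equivalent to torch.max(torch.min(quantized, quant_max), quant_min).
        return torch.clamp(quantized, quant_min, quant_max)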

@@ -247,12 +247,12 @@ def quantized_linear(
     ).reshape(*leading_dims, N)


-@impl(m, "quantized_layer_norm_per_tensor")
+@impl(m, "quantized_layer_norm.per_tensor")
 def quantized_layer_norm_per_tensor(
     input_tensor: torch.Tensor,
     X_scale: float,
     X_zero_point: int,
-    normalized_shape: int,
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float,
@@ -283,7 +283,7 @@ def quantized_layer_norm_per_tensor(
         input_tensor, X_scale, X_zero_point, -128, 127, torch.float32
     )
     out = torch.nn.functional.layer_norm(
-        float_input_tensor, (normalized_shape,), weight, bias, eps=eps
+        float_input_tensor, normalized_shape, weight, bias, eps=eps
     )

     return quantize_per_tensor(
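
For context, the kernel follows the usual dequantize -> float layer_norm -> requantize pattern visible in this hunk. A simplified, self-contained sketch of that pattern (standard affine convention assumed; the cadence helpers may use a different scale convention, and the exact op signature is not fully shown here):

    import torch

    def quantized_layer_norm_sketch(x_q, x_scale, x_zero_point, normalized_shape,
                                    weight, bias, eps, out_scale, out_zero_point):
        # Dequantize to float32, normalize, then requantize to int8.
        x_f = (x_q.to(torch.float32) - x_zero_point) * x_scale
        out_f = torch.nn.functional.layer_norm(
            x_f, normalized_shape, weight, bias, eps=eps
        )
        out_q = torch.round(out_f / out_scale + out_zero_point)
        return torch.clamp(out_q, -128, 127).to(torch.int8)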
@@ -365,7 +365,7 @@ def quantized_conv_per_tensor(
     )


-@impl(m, "quantized_conv_nchw_per_tensor")
+@impl(m, "quantized_conv_nchw.per_tensor")
 def quantized_conv_nchw_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
@@ -421,7 +421,7 @@ def quantized_conv_nchw_per_tensor(
     )


-@impl(m, "quantized_conv_nhwc_per_tensor")
+@impl(m, "quantized_conv_nhwc.per_tensor")
 def quantized_conv_nhwc_per_tensor(
     input_tensor: torch.Tensor,
     weight: torch.Tensor,
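
The conv renames are the same mechanical change as the layer_norm one: the trailing _per_tensor in the registered string becomes the .per_tensor overload form. The dot matters because torch parses the impl string as "op_name.overload_name". A sketch of what this enables, with the cadence namespace assumed:

    import torch

    # "quantized_conv_nchw.per_tensor" binds the kernel to the per_tensor
    # overload of the quantized_conv_nchw schema, so it is reachable as:
    conv_op = torch.ops.cadence.quantized_conv_nchw.per_tensor
    # The old string "quantized_conv_nchw_per_tensor" would instead require a
    # distinct op of that exact name to have been declared.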
@@ -558,62 +558,62 @@ def variant(
     return decorator


-@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nchw", torch.int8, torch.int8)
 def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nchw", torch.uint8, torch.uint8)
 def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor")
 @quantized_conv_variant("nhwc", torch.int8, torch.int8)
 def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


-@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor")
+@impl(m, "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor")
 @quantized_conv_variant("nhwc", torch.uint8, torch.uint8)
 def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...
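All twelve dtype-specialized stubs above are body-less on purpose: the quantized_conv_variant factory (whose variant/decorator internals are visible in the hunk header) supplies the real implementation. A minimal sketch of that pattern under assumptions (this is not the actual cadence internals):

    import torch

    def quantized_conv_nchw_per_tensor(input_tensor, weight, *args, **kwargs):
        ...  # stand-in for the generic NCHW kernel defined earlier in the file

    def quantized_conv_nhwc_per_tensor(input_tensor, weight, *args, **kwargs):
        ...  # stand-in for the generic NHWC kernel defined earlier in the file

    def quantized_conv_variant(layout, input_dtype, weight_dtype):
        # Factory returning a decorator; the decorated stub's body is discarded.
        def decorator(_stub):
            def variant(input_tensor, weight, *args, **kwargs):
                # Enforce the baked-in dtypes, then dispatch on layout.
                assert input_tensor.dtype == input_dtype
                assert weight.dtype == weight_dtype
                if layout == "nchw":
                    return quantized_conv_nchw_per_tensor(input_tensor, weight, *args, **kwargs)
                return quantized_conv_nhwc_per_tensor(input_tensor, weight, *args, **kwargs)
            return variant
        return decorator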
