
Commit 2b2618e

Authored by xadupre, github-code-quality[bot], and justinchuby
Add converter torch aten::histc (#2796)
Fixes #2794.

Signed-off-by: Xavier Dupré <xadupre@microsoft.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 182d83a commit 2b2618e
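
For context, torch.histc splits the closed interval [min, max] into bins equal-width buckets and counts how many input elements land in each one; elements outside the range are ignored, and the upper edge is inclusive. A minimal sketch of the behavior the new converter has to reproduce (the input values are chosen only for illustration):

import torch

x = torch.tensor([0.0, 0.4, 0.5, 1.0, 1.9, 2.0, 2.5])
# Three equal-width bins over [0, 2]: [0, 2/3), [2/3, 4/3), [4/3, 2].
# 2.5 is outside the range and dropped; 2.0 lands in the last, inclusive bin.
print(torch.histc(x, bins=3, min=0.0, max=2.0))  # tensor([3., 1., 2.])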

File tree: 3 files changed, +203 −65 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 42 additions & 1 deletion

@@ -4608,12 +4608,53 @@ def aten_hinge_embedding_loss(
     raise NotImplementedError()
 
 
+@torch_op("aten::histc", trace_only=True)
 def aten_histc(
     self: TensorType, bins: int = 100, min: float = 0.0, max: float = 0.0
 ) -> TensorType:
     """histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor"""
+    if min == max:
+        # This ONNXScript implementation precomputes static bin edges and cannot
+        # faithfully reproduce torch.histc's dynamic behavior when min == max
+        # (including the default min=0, max=0, which infers the range from data).
+        raise NotImplementedError(
+            f"aten_histc with min == max ({min}) is not supported in this export path."
+        )
+    delta = (max - min) / (bins * 1.0)
+    values = [min + delta * i for i in range(bins + 1)]
 
-    raise NotImplementedError()
+    flat_self = op.Reshape(self, [-1])
+    computation_type = self.dtype
+
+    cond = op.And(
+        op.GreaterOrEqual(flat_self, op.CastLike([min], self)),
+        op.LessOrEqual(flat_self, op.CastLike([max], self)),
+    )
+
+    assert self.type.dtype not in {ir.DataType.INT32, ir.DataType.INT64}, (
+        f"torch.histc only works on float but {self.type.dtype=}"
+    )
+
+    cond = op.And(cond, op.Not(op.IsNaN(flat_self)))
+    # max is included.
+    dtype = self.type.dtype.numpy()
+    values = np.array(values, dtype=dtype)
+    values[-1] = np.nextafter(values[-1], np.array(np.inf, dtype=dtype), dtype=dtype)
+    typed_values = op.Constant(value=ir.tensor(values, dtype=self.type.dtype))
+
+    clipped = op.Where(cond, flat_self, op.CastLike([min - 1], self))
+    bins = op.Unsqueeze(typed_values, [1])
+
+    less = op.Cast(
+        op.Less(op.Unsqueeze(clipped, [0]), bins),
+        to=computation_type,
+    )
+    sums = op.ReduceSum(less, [1], keepdims=0)
+    res = op.Sub(
+        op.Slice(sums, [1], op.Shape(sums), [0]),
+        op.Slice(sums, [0], [-1], [0]),
+    )
+    return res
 
 
 def aten_histogramdd(
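
The new implementation uses a counting trick rather than a per-bin loop: for each of the bins + 1 precomputed edges it counts how many in-range elements are strictly below that edge (op.Less followed by op.ReduceSum), then takes differences of adjacent counts (op.Sub of two op.Slice results) to get per-bin totals. Nudging the last edge upward with np.nextafter makes max inclusive, and out-of-range or NaN elements are parked at min - 1 so they cancel out in the differences. A rough NumPy sketch of the same idea (not the exported ONNX graph itself, and the helper name is made up):

import numpy as np

def histc_sketch(x: np.ndarray, bins: int, min_: float, max_: float) -> np.ndarray:
    # bins + 1 static edges, mirroring the precomputed values list above.
    edges = np.array([min_ + (max_ - min_) / bins * i for i in range(bins + 1)], dtype=x.dtype)
    # Nudge the last edge just past max so the "strictly less than" test includes max.
    edges[-1] = np.nextafter(edges[-1], np.array(np.inf, dtype=x.dtype), dtype=x.dtype)
    flat = x.reshape(-1)
    in_range = (flat >= min_) & (flat <= max_) & ~np.isnan(flat)
    # Park out-of-range / NaN values below every edge so they cancel in the differences.
    parked = np.where(in_range, flat, min_ - 1)
    # counts_below[j] = number of elements strictly below edge j, shape (bins + 1,).
    counts_below = (parked[None, :] < edges[:, None]).sum(axis=1)
    # Adjacent differences turn cumulative edge counts into per-bin counts.
    return (counts_below[1:] - counts_below[:-1]).astype(x.dtype)

For the example above, histc_sketch(np.array([0.0, 0.4, 0.5, 1.0, 1.9, 2.0, 2.5], dtype=np.float32), 3, 0.0, 2.0) returns [3., 1., 2.], matching torch.histc.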

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 45 additions & 0 deletions

@@ -914,6 +914,51 @@ def forward(self, x):
         )
         _testing.assert_onnx_program(onnx_program)
 
+    def test_aten_histc_float(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.histc(x, 3, 0, 2)
+
+        model = Model()
+        onnx_program = torch.onnx.export(
+            Model(),
+            (torch.rand(10, 10, 10),),
+            dynamo=True,
+            verbose=False,
+            dynamic_shapes=({0: "batch"},),
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+        for k in range(101):
+            with self.subTest(k=k):
+                inputs = (torch.tensor([(k - 1) / 49.0], dtype=torch.float32),)
+                expected = model(*inputs)
+                got = onnx_program.call_reference({"x": inputs[0]})
+                torch.testing.assert_close(expected, got[0])
+
+    @unittest.skip("see https://github.com/pytorch/pytorch/issues/174668")
+    def test_aten_histc_float16(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return torch.histc(x, 60, -10, 10)
+
+        model = Model()
+        onnx_program = torch.onnx.export(
+            Model(),
+            (torch.rand((10, 10, 10), dtype=torch.float16),),
+            dynamo=True,
+            verbose=False,
+            dynamic_shapes=({0: "batch"},),
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+        for k in range(101):
+            with self.subTest(k=k):
+                inputs = (torch.tensor([(k - 1) / 49.0], dtype=torch.float16),)
+                expected = model(*inputs)
+                got = onnx_program.call_reference({"x": inputs[0]})
+                torch.testing.assert_close(expected, got[0])
+
 
 if __name__ == "__main__":
     unittest.main()
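
Beyond these unit tests, the converter can be exercised the same way end to end: export a module that calls torch.histc with dynamo=True and compare against eager PyTorch. An illustrative sketch (the module name and output path are made up; it assumes a recent torch with the dynamo-based exporter):

import torch

class HistcModel(torch.nn.Module):
    def forward(self, x):
        return torch.histc(x, bins=4, min=0.0, max=1.0)

x = torch.rand(32)
# With dynamo=True, torch.onnx.export returns an ONNXProgram wrapping the exported graph.
onnx_program = torch.onnx.export(HistcModel(), (x,), dynamo=True)
onnx_program.save("histc.onnx")  # assumed output path

# Eager result for comparison; the exported graph is expected to produce the same counts.
print(HistcModel()(x))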
