Skip to content

Commit 38bd90b

Browse files
Copilot authored and justinchuby committed
Implement aten_bilinear function using Einsum operation
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent d70d4be commit 38bd90b

File tree

3 files changed

+57
-1
lines changed

3 files changed

+57
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,7 @@ def aten_bernoulli_p(self: TTensor, p: float) -> TTensor:
11891189
return op.CastLike(sampled, self)
11901190

11911191

1192+
@torch_op("aten::bilinear", trace_only=True)
11921193
def aten_bilinear(
11931194
input1: TensorType,
11941195
input2: TensorType,
@@ -1197,7 +1198,23 @@ def aten_bilinear(
11971198
) -> TensorType:
11981199
"""bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor"""
11991200

1200-
raise NotImplementedError()
1201+
# Bilinear transformation: y = x1^T A x2 + b
1202+
# input1 shape: (..., in1_features)
1203+
# input2 shape: (..., in2_features)
1204+
# weight shape: (out_features, in1_features, in2_features)
1205+
# bias shape: (out_features) - optional
1206+
# output shape: (..., out_features)
1207+
1208+
# Use Einsum to compute the bilinear transformation
1209+
# "...i,oij,...j->...o" means:
1210+
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
1211+
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
1212+
1213+
# Add bias if provided
1214+
if bias is not None:
1215+
result = op.Add(result, bias)
1216+
1217+
return result
12011218

12021219

12031220
def aten_binary_cross_entropy_with_logits(

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,37 @@ def sample_inputs_scalar_tensor(op_info, device, dtype, requires_grad, **kwargs)
3737
yield opinfo_core.SampleInput(item, dtype=dtype)
3838

3939

40+
def sample_inputs_bilinear(op_info, device, dtype, requires_grad, **kwargs):
41+
"""Sample inputs for bilinear operation."""
42+
del op_info
43+
del kwargs
44+
45+
make_arg = functools.partial(
46+
torch_testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
47+
)
48+
49+
# Test cases: (batch_size, in1_features, in2_features, out_features)
50+
cases = [
51+
(2, 3, 4, 5), # Basic case
52+
(1, 2, 2, 1), # Minimal case
53+
(3, 5, 7, 4), # Different dimensions
54+
(2, 1, 1, 3), # Single input features
55+
]
56+
57+
for batch_size, in1_features, in2_features, out_features in cases:
58+
input1 = make_arg((batch_size, in1_features))
59+
input2 = make_arg((batch_size, in2_features))
60+
weight = make_arg((out_features, in1_features, in2_features))
61+
bias = make_arg((out_features,))
62+
63+
# Test with bias
64+
yield opinfo_core.SampleInput(input1, args=(input2, weight, bias))
65+
66+
# Test without bias (only for first case to avoid too many tests)
67+
if batch_size == 2:
68+
yield opinfo_core.SampleInput(input1, args=(input2, weight, None))
69+
70+
4071
def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
4172
del op_info
4273

@@ -2180,6 +2211,13 @@ def __init__(self):
21802211
# To avoid name duplication, it is possible to rename the OpInfo and specify
21812212
# the `op` field explicitly.
21822213
OP_DB: List[opinfo_core.OpInfo] = [
2214+
opinfo_core.OpInfo(
2215+
"bilinear",
2216+
op=torch.nn.functional.bilinear,
2217+
dtypes=common_dtype.floating_types(),
2218+
sample_inputs_func=sample_inputs_bilinear,
2219+
supports_out=False,
2220+
),
21832221
opinfo_core.OpInfo(
21842222
"ops.aten.bernoulli.p",
21852223
aten_name="bernoulli.p",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,7 @@ def _where_input_wrangler(
657657
),
658658
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm, tolerance={torch.float16: (1e-3, 1e-2)}),
659659
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
660+
TorchLibOpInfo("bilinear", core_ops.aten_bilinear, tolerance={torch.float32: (1e-4, 1e-4)}),
660661
TorchLibOpInfo(
661662
# This string is a unique ID. In extra_opinfo.py, we
662663
# also define test data for this ID with

0 commit comments

Comments
 (0)