Skip to content

Commit cac10b8

Browse files
jbschlosser authored and pytorchmergebot committed
Fix NJT OpInfo entry for nn.functional.prelu (pytorch#144582)
Part of my BE project addressing NJT bugs surfaced via OpInfo tests. The OpInfo entry for prelu was wrong before this PR; `weight` needs to be passed as well. The op isn't fully implemented yet. Pull Request resolved: pytorch#144582 Approved by: https://github.com/soulitzer
1 parent eaef613 commit cac10b8

File tree

2 files changed

+37
-7
lines changed

2 files changed

+37
-7
lines changed

test/test_nestedtensor.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7909,6 +7909,7 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
79097909
# unary
79107910
# needs log_sigmoid_forward, which returns a tuple
79117911
"nn.functional.logsigmoid",
7912+
"nn.functional.prelu",
79127913
# needs rrelu_with_noise
79137914
"nn.functional.rrelu",
79147915
# binary
@@ -7972,13 +7973,6 @@ def __torch_dispatch__(self, func, types, args=..., kwargs=None):
79727973
},
79737974
name="no_masked_jagged_support",
79747975
),
7975-
# Need to adjust sample input func to pass the right thing
7976-
XFailRule(
7977-
error_type=TypeError,
7978-
error_msg="missing 1 required positional arguments",
7979-
op_match_fn=lambda device, op: (op.full_name == "nn.functional.prelu"),
7980-
name="invalid_prelu_sample_input_func",
7981-
),
79827976
# Op doesn't support lengths being present
79837977
XFailRule(
79847978
error_type=ValueError,

torch/testing/_internal/opinfo/definitions/nested.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,41 @@ def sample_inputs_nn_functional_linear(op_info, device, dtype, requires_grad, **
12171217
)
12181218

12191219

1220+
def sample_inputs_nn_functional_prelu(op_info, device, dtype, requires_grad, **kwargs):
    """Yield SampleInputs for nn.functional.prelu over nested jagged tensors.

    For every sampled 3D/4D NJT whose channel dim (dim 1) is non-ragged,
    two samples are produced: one with a 1D per-channel ``weight`` and one
    with a scalar tensor ``weight``.
    """
    sampled = _sample_njts(
        device=device, dtype=dtype, requires_grad=requires_grad, dims=[3, 4]
    )
    for njt in sampled:
        # Second dim is interpreted as number of channels; this should be non-ragged for now
        num_channels = njt.size(1)
        if is_nested_int(num_channels):
            continue

        # 1D weight: one slope value per channel
        per_channel_weight = torch.randn(
            num_channels, device=device, dtype=dtype, requires_grad=requires_grad
        )
        yield SampleInput(
            _clone(njt),
            kwargs={"weight": _clone(per_channel_weight)},
            name=f"{_describe_njt(njt)}: 1D weight",
        )

        # scalar tensor weight: a single slope shared across channels
        # (no requires_grad here, matching the original sample)
        scalar_weight = torch.tensor(4.2, device=device, dtype=dtype)
        yield SampleInput(
            _clone(njt),
            kwargs={"weight": scalar_weight},
            name=f"{_describe_njt(njt)}: scalar tensor weight",
        )
1254+
12201255
def sample_inputs_nn_functional_rms_norm(
12211256
op_info, device, dtype, requires_grad, **kwargs
12221257
):
@@ -1412,6 +1447,7 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
14121447
"nn.functional.embedding": sample_inputs_nn_functional_embedding,
14131448
"nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag,
14141449
"nn.functional.linear": sample_inputs_nn_functional_linear,
1450+
"nn.functional.prelu": sample_inputs_nn_functional_prelu,
14151451
"nn.functional.rms_norm": sample_inputs_nn_functional_rms_norm,
14161452
"nn.functional.threshold": sample_inputs_nn_functional_threshold,
14171453
**{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)},

0 commit comments

Comments
 (0)