
Commit 7e5b37c

noobsiecoder, YifanShenSZ, and yifan_shen3 authored
Fix Issue #2583: Dynamic padding in torch.nn.functional.pad (#2588)
* Fix Issue #2583: Dynamic padding in torch.nn.functional.pad

  Modified _array_construct to handle dynamic padding values:
  - Creates proper Var objects using mb.concat instead of Python lists
  - Fixes AttributeError when converting models with x.size(-1) padding

* Limit torch to older than 2.8 for now (#2591)

  Co-authored-by: yifan_shen3 <[email protected]>

* Add RMSNorm operator support for PyTorch to CoreML conversion (#2585) (#2592)

* Add RMSNorm operator support for PyTorch to CoreML conversion (#2585)

* Formatted code

* Handle FP16 overflow for the RMSNorm operation

* Handle dynamic padding without breaking legacy code

---------

Co-authored-by: Yifan Shen <[email protected]>
Co-authored-by: yifan_shen3 <[email protected]>
1 parent 02450cf commit 7e5b37c
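
For context, the failure this commit fixes shows up when a model's pad widths depend on the input's runtime size. The sketch below is illustrative only: the module name, enumerated shapes, and input/output names are assumptions chosen to mirror the new tests further down this page, not code taken from the original issue report.

import numpy as np
import torch
import coremltools as ct

class PadWithDynamicSize(torch.nn.Module):  # hypothetical module name
    def forward(self, x):
        # Right pad width depends on the runtime length of x
        return torch.nn.functional.pad(x, (1, x.size(-1)))

example = torch.rand(3)
traced = torch.jit.trace(PadWithDynamicSize(), example)

# Before this fix, conversion raised an AttributeError because the pad values
# reached the converter as a Python list containing a non-const Var.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(
        shape=ct.EnumeratedShapes(shapes=[[2], [3], [4]], default=[3]),
        dtype=np.float32,
        name="input",
    )],
    outputs=[ct.TensorType(name="output", dtype=np.float32)],
    convert_to="mlprogram",
)
print(mlmodel.predict({"input": example.numpy()})["output"].shape)  # (7,) for a length-3 input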

File tree

2 files changed: +234 -4 lines changed


coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 48 additions & 4 deletions
@@ -903,9 +903,6 @@ def _array_construct(context, node, array_type):
         const = mb.const(val=val, name=node.name)
         context.add(const)
     else:
-        # If at least one input to the construct op is non-const, collect
-        # the inputs and add them directly to the context. Ops that use this
-        # node's output will take the list directly as input.
         context.add(array_type(inputs), node.name)


@@ -2387,7 +2384,54 @@ def _parse_keyword_args(context, node, mode: Var, value: Var) -> Tuple[Var]:
         return mode, value

     def _translate_torch_args(pad: Var, mode: Var, value: Var) -> Tuple[Var]:
-        if pad.val is not None:
+        # Check if pad is a list (which happens when `_array_construct` returns
+        # a list for dynamic values). When _array_construct returns a list,
+        # it means at least one value is dynamic (not a compile-time constant).
+        if isinstance(pad, list):
+            # NOTE:
+            # - Core ML's `mb.pad` operation only supports dynamic padding for
+            #   1D tensors.
+            # - For n-dimensional tensors (2D, 3D, 4D, etc.), dynamic padding
+            #   values cause runtime errors even when formatted correctly.
+            # - This is a fundamental limitation of the Core ML framework,
+            #   not of this converter.
+            if len(pad) == 2 and x.rank == 1:
+                tensor_inputs = []
+                for inp in pad:
+                    if isinstance(inp, (int, float)):
+                        # Convert plain number to const Var
+                        const_var = mb.const(val=[inp])
+                        tensor_inputs.append(const_var)
+                    elif isinstance(inp, Var):
+                        if len(inp.shape) == 0:  # Scalar Var
+                            # Convert scalar to 1D tensor
+                            tensor_inp = mb.expand_dims(x=inp, axes=[0])
+                            tensor_inputs.append(tensor_inp)
+                        else:
+                            tensor_inputs.append(inp)
+                    else:
+                        tensor_inputs.append(inp)
+                # Concatenate into a single tensor Var with shape (n,)
+                pad = mb.concat(values=tensor_inputs, axis=0)
+            else:
+                # Dynamic padding for n-dimensional tensors is not supported
+                # by Core ML.
+                # This includes:
+                # - 1D padding on multi-dimensional tensors
+                #   (e.g., padding only the last dim of a 2D tensor)
+                # - Multi-dimensional padding with any dynamic values
+                #
+                # Although it works for n dimensions when emulated with
+                # MIL operations (reshape, reverse, concat),
+                # Core ML's mb.pad operation fails at runtime when given
+                # dynamic padding for n-dimensional tensors.
+                raise NotImplementedError(
+                    f"Dynamic padding for n-dimensional tensors is not "
+                    f"supported. "
+                    f"Received {len(pad)} padding values. "
+                    f"Only 1D dynamic padding (2 values) is supported."
+                )
+        elif pad.val is not None:
             # torch.nn.functional.pad has different semantics from Core ML
             # * for torch.nn.functional.pad
             #     x.shape[-1] = padding[0] + x.shape[-1] + padding[1]
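
To make the new 1D branch above concrete, here is a standalone MIL Builder sketch of the same pattern: a constant pad amount and a runtime-derived pad amount each become a rank-1 Var, are joined with mb.concat, and the combined tensor is passed to mb.pad. This is not the converter's internal code path; the fixed input shape (3,) and the use of mb.shape to obtain the dynamic value are assumptions made for illustration.

from coremltools.converters.mil import Builder as mb

@mb.program(input_specs=[mb.TensorSpec(shape=(3,))])
def prog(x):
    left = mb.const(val=[1])                       # constant pad amount as a shape-(1,) Var
    right = mb.shape(x=x)                          # runtime length of x, also shape (1,)
    pad = mb.concat(values=[left, right], axis=0)  # single pad tensor of shape (2,)
    return mb.pad(x=x, pad=pad, mode="constant", constant_val=0.0)

print(prog)  # inspect the resulting MIL program

As the NOTE and the NotImplementedError branch above state, this path is only taken for rank-1 inputs with two pad values; dynamic padding of higher-rank tensors still raises NotImplementedError.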
Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
+# Copyright (c) 2020, Apple Inc. All rights reserved.
+#
+# Use of this source code is governed by a BSD-3-clause license that can be
+# found in the LICENSE.txt file or at
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Test suite for dynamic padding conversion (Issue #2583)
+# These tests verify the fix for converting PyTorch pad operations with
+# runtime-determined padding values to Core ML.
+# The issue occurred in _translate_torch_args() when handling
+# dynamic padding values like (1, x.size(-1)).
+
+import pytest
+from coremltools._deps import _HAS_TORCH
+import numpy as np
+
+# Check if the pytorch module is installed
+# Also, check if the pytorch and coremltools versions are compatible for this test
+if _HAS_TORCH:
+    import torch
+    import coremltools as ct
+
+    # get package versions
+    torch_major = int(torch.__version__.split('.')[0])
+    ct_version_parts = ct.__version__.split('.')
+    ct_major = int(ct_version_parts[0])
+
+    # Run only on PyTorch 2.x and coremltools >= 8.x
+    _TORCH_COMPATIBLE = torch_major >= 2
+    _CT_COMPATIBLE = ct_major >= 8
+    _VERSIONS_COMPATIBLE = _TORCH_COMPATIBLE and _CT_COMPATIBLE
+else:
+    _VERSIONS_COMPATIBLE = False
+
+
+@pytest.mark.skipif(not _HAS_TORCH, reason="PyTorch not found")
+@pytest.mark.skipif(not _VERSIONS_COMPATIBLE, reason="Incompatible versions")
+class TestPadDynamicFix:
+    """
+    Test dynamic padding fix for Issue #2583 - torch.nn.functional.pad
+    with x.size(-1)
+    """
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "input_size, pad_fn, expected_size, test_name",
+        [
+            # Dynamic padding tests
+            (3, lambda x: (1, x.size(-1)), 7, "dynamic_right"),
+            (5, lambda x: (0, x.size(-1)), 10, "dynamic_right_only"),
+            (4, lambda x: (x.size(-1), 0), 8, "dynamic_left_only"),
+            (2, lambda x: (x.size(-1), x.size(-1)), 6, "both_dynamic"),
+        ]
+    )
+    def test_dynamic_padding(input_size, pad_fn, expected_size, test_name):
+        """
+        Test dynamic padding cases where pad values depend on input size
+        """
+        class TestModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.pad(x, pad_fn(x))
+
+        model = TestModel()
+        example = torch.rand(input_size)
+        traced = torch.jit.trace(model, example)
+
+        mlmodel = ct.convert(
+            traced,
+            inputs=[ct.TensorType(
+                shape=ct.EnumeratedShapes(
+                    shapes=[[2], [3], [4], [5], [input_size]],
+                    default=[input_size],
+                ),
+                dtype=np.float32,
+                name="input"
+            )],
+            outputs=[ct.TensorType(name="output", dtype=np.float32)],
+            convert_to="mlprogram"
+        )
+
+        result = mlmodel.predict({"input": example.numpy()})
+        assert result["output"].shape[0] == expected_size, \
+            f"Test '{test_name}' failed: expected shape ({expected_size},), " \
+            f"got {result['output'].shape}"
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "input_size,pad_fn,expected_size,test_name",
+        [
+            # Constant padding tests (regression test)
+            (3, lambda x: (1, 2), 6, "both_constant"),
+            (4, lambda x: (0, 3), 7, "constant_right_only"),
+            (5, lambda x: (2, 0), 7, "constant_left_only"),
+            (2, lambda x: (3, 4), 9, "large_constants"),
+        ]
+    )
+    def test_constant_padding(input_size, pad_fn, expected_size, test_name):
+        """
+        Test constant padding cases - regression test
+        """
+        class TestModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.pad(x, pad_fn(x))
+
+        model = TestModel()
+        example = torch.rand(input_size)
+        traced = torch.jit.trace(model, example)
+
+        mlmodel = ct.convert(
+            traced,
+            inputs=[ct.TensorType(
+                shape=ct.EnumeratedShapes(
+                    shapes=[[2], [3], [4], [5], [input_size]],
+                    default=[input_size],
+                ),
+                dtype=np.float32,
+                name="input"
+            )],
+            outputs=[ct.TensorType(name="output", dtype=np.float32)],
+            convert_to="mlprogram"
+        )
+
+        result = mlmodel.predict({"input": example.numpy()})
+        output = result["output"]
+
+        # Verify shape
+        assert output.shape[0] == expected_size, \
+            f"Test '{test_name}' failed: expected shape ({expected_size},), " \
+            f"got {output.shape}"
+
+        # Verify padding values are zeros
+        pad_config = pad_fn(example)
+        left_pad, right_pad = pad_config
+
+        if left_pad > 0:
+            assert np.allclose(output[:left_pad], 0.0), \
+                f"Test '{test_name}' failed: left padding should be zeros"
+
+        assert np.allclose(
+            output[left_pad:left_pad + input_size], example.numpy()
+        ), \
+            f"Test '{test_name}' failed: original values not preserved"
+
+        if right_pad > 0:
+            assert np.allclose(output[-right_pad:], 0.0), \
+                f"Test '{test_name}' failed: right padding should be zeros"
+
+    @staticmethod
+    @pytest.mark.parametrize(
+        "input_size,pad_fn,expected_size,test_name",
+        [
+            # Mixed padding tests
+            (3, lambda x: (2, x.size(-1)), 8, "constant_left_dynamic_right"),
+            (4, lambda x: (x.size(-1), 3), 11, "dynamic_left_constant_right"),
+        ]
+    )
+    def test_mixed_padding(input_size, pad_fn, expected_size, test_name):
+        """
+        Test mixed padding cases with both constant and dynamic values
+        """
+        class TestModel(torch.nn.Module):
+            def forward(self, x):
+                return torch.nn.functional.pad(x, pad_fn(x))
+
+        model = TestModel()
+        example = torch.rand(input_size)
+        traced = torch.jit.trace(model, example)
+
+        mlmodel = ct.convert(
+            traced,
+            inputs=[ct.TensorType(
+                shape=ct.EnumeratedShapes(
+                    shapes=[[2], [3], [4], [5], [input_size]],
+                    default=[input_size],
+                ),
+                dtype=np.float32,
+                name="input"
+            )],
+            outputs=[ct.TensorType(name="output", dtype=np.float32)],
+            convert_to="mlprogram"
+        )
+
+        result = mlmodel.predict({"input": example.numpy()})
+        assert result["output"].shape[0] == expected_size, \
+            f"Test '{test_name}' failed: expected shape ({expected_size},), " \
+            f"got {result['output'].shape}"
