Skip to content

Commit a348a7b

Browse files
committed
up
1 parent 07a057b commit a348a7b

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

backends/apple/coreml/runtime/delegate/multiarray.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,7 @@ class MultiArray final {
131131
*ptr = value;
132132
}
133133

134-
void resize(const std::vector<size_t>& shape) {
135-
layout_.resize(shape);
136-
}
134+
void resize(const std::vector<size_t>& shape) { layout_.resize(shape); }
137135

138136
private:
139137
void* data(const std::vector<size_t>& indices) const noexcept;

backends/apple/coreml/runtime/test/BackendDelegateTests.mm

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ - (void)testAddModelExecution {
162162
MLMultiArray *output = [ETCoreMLTestUtils filledMultiArrayWithShape:inputs[0].shape dataType:inputs[0].dataType repeatedValue:@(0) error:&localError];
163163
NSArray<MLMultiArray *> *args = [inputs arrayByAddingObject:output];
164164
std::error_code errorCode;
165+
auto argsVec = to_multiarrays(args);
165166
XCTAssertTrue(_delegate->execute(handle,
166-
to_multiarrays(args),
167+
argsVec,
167168
ModelLoggingOptions(),
168169
nullptr,
169170
errorCode));
@@ -187,8 +188,9 @@ - (void)testMulModelExecution {
187188
MLMultiArray *output = [ETCoreMLTestUtils filledMultiArrayWithShape:inputs[0].shape dataType:inputs[0].dataType repeatedValue:@(0) error:&localError];
188189
NSArray<MLMultiArray *> *args = [inputs arrayByAddingObject:output];
189190
std::error_code errorCode;
190-
XCTAssertTrue(_delegate->execute(handle,
191-
to_multiarrays(args),
191+
auto argsVec = to_multiarrays(args);
192+
XCTAssertTrue(_delegate->execute(handle,
193+
argsVec,
192194
ModelLoggingOptions(),
193195
nullptr,
194196
errorCode));

examples/apple/coreml/scripts/export.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,12 @@ def parse_args() -> argparse.ArgumentParser:
7676
parser.add_argument("--use_partitioner", action=argparse.BooleanOptionalAction)
7777
parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction)
7878
parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction)
79+
parser.add_argument(
80+
"--dynamic_shapes",
81+
action=argparse.BooleanOptionalAction,
82+
required=False,
83+
default=False,
84+
)
7985

8086
args = parser.parse_args()
8187
# pyre-fixme[7]: Expected `ArgumentParser` but got `Namespace`.
@@ -164,16 +170,20 @@ def main():
164170
f"Valid compute units are {valid_compute_units}."
165171
)
166172

167-
model, example_inputs, _, _ = EagerModelFactory.create_model(
173+
model, example_inputs, _, dynamic_shapes = EagerModelFactory.create_model(
168174
*MODEL_NAME_TO_MODEL[args.model_name]
169175
)
176+
if not args.dynamic_shapes:
177+
dynamic_shapes = None
170178

171179
compile_specs = generate_compile_specs_from_args(args)
172180
lowered_module = None
173181

174182
if args.use_partitioner:
175183
model.eval()
176-
exir_program_aten = torch.export.export(model, example_inputs, strict=True)
184+
exir_program_aten = torch.export.export(
185+
model, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
186+
)
177187

178188
edge_program_manager = exir.to_edge(exir_program_aten)
179189
edge_copy = copy.deepcopy(edge_program_manager)

0 commit comments

Comments
 (0)