Commit 80ed456

mcr229 authored and facebook-github-bot committed
handle static ints and floats (#140)
Summary: Pull Request resolved: #140

In order to support MV3, which contains decomposed `hardswish` and `hardsigmoid`, we need to handle the static scalars those decompositions introduce. Decomposition rules for both:

### Hardswish

https://www.internalfb.com/code/fbsource/[9368f8417bd843ee8c91e24ac616ed7f4b194ed8]/xplat/caffe2/torch/_decomp/decompositions.py?lines=182-185

### Hardsigmoid

https://www.internalfb.com/code/fbsource/[9368f8417bd843ee8c91e24ac616ed7f4b194ed8]/xplat/caffe2/torch/_decomp/decompositions.py?lines=159-162

### Fixing zero-dim tensors

Both of these decompositions produce zero-dim tensors in the graph (the `+ 3` and the `/ 6`). This breaks XNNPACK, which has no notion of zero-dim tensors. Instead, if the static data is zero-dim, we interpret its shape as `[1]`.

#### Fixing torch.int64 static data

In the decomposition, the `3` is converted via `to_copy(torch.float32)`, but the `6` remains an `int64`. XNNPACK does not handle non-quantized integers, so we also cast all non-quantized static data to `float32`.

Reviewed By: digantdesai

Differential Revision: D48667679

fbshipit-source-id: 0b20363bde480a98349e65ae8f569c7c95c95ef6
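For reference, the decompositions linked above can be sketched as plain PyTorch (a minimal sketch following the standard formulas `hardswish(x) = x * relu6(x + 3) / 6` and `hardsigmoid(x) = relu6(x + 3) / 6`; the function names here are illustrative, not the actual decomp-table entries). The snippet also shows why the scalar constants surface as zero-dim tensors, and why `6` keeps an integer dtype:

```python
import torch

def hardswish_decomposed(x: torch.Tensor) -> torch.Tensor:
    # x * relu6(x + 3) / 6
    return x * torch.clamp(x + 3, 0, 6) / 6

def hardsigmoid_decomposed(x: torch.Tensor) -> torch.Tensor:
    # relu6(x + 3) / 6
    return torch.clamp(x + 3, 0, 6) / 6

# When traced, the scalar constants become zero-dim tensors in the graph:
three = torch.tensor(3.0)  # converted via to_copy(torch.float32) in the decomp
six = torch.tensor(6)      # stays torch.int64 -- the problematic case
print(three.shape)         # torch.Size([]) -- zero-dim, which XNNPACK rejects
print(six.dtype)           # torch.int64 -- non-quantized int, also rejected
```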
1 parent e425377 commit 80ed456

File tree

2 files changed: +8 −0 lines changed


backends/xnnpack/operators/node_visitor.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -230,6 +230,7 @@ def define_tensor(
         # Get new xnn id for tensor value
         ext_id, id_out, flag = self.gen_ids_and_flags(tensor, xnn_graph, quant_params)
         dims = get_shape(tensor)
+        dims = [1] if len(dims) == 0 else dims

         # constant values serialize data
         buffer_idx = self.get_serialized_buffer(
@@ -336,6 +337,10 @@ def get_serialized_buffer(
         # Quantize buffer if static data is indeed quantized
         if quant_params is not None and not quant_params.is_dynamic:
             const_val = quant_params.quantize_tensor(const_val).contiguous()
+        else:
+            # ensure that the const is fp32
+            const_val = const_val.to(dtype=torch.float32).contiguous()
         if swap_nc_for_depthwise_weights:
             const_val = const_val.permute(
                 dims=((1, 0) + tuple(range(2, const_val.dim())))
```
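The two fixes in this file can be exercised in isolation. The helper below is hypothetical (it is not the actual ExecuTorch code, which lives inside `define_tensor` / `get_serialized_buffer`), but it mirrors both changes: a zero-dim shape is reinterpreted as `[1]`, and non-quantized static data is cast to `float32`:

```python
import torch

def normalize_static_data(const_val: torch.Tensor, is_quantized: bool = False):
    """Hypothetical sketch of the node_visitor change: zero-dim shapes
    become [1], and non-quantized static data is cast to float32."""
    dims = list(const_val.shape)
    dims = [1] if len(dims) == 0 else dims  # XNNPACK has no zero-dim tensors
    if not is_quantized:
        # ensure that the const is fp32 (XNNPACK rejects non-quantized ints)
        const_val = const_val.to(dtype=torch.float32).contiguous()
    return const_val, dims

# The int64 "6" constant from the hardswish decomposition:
val, dims = normalize_static_data(torch.tensor(6))
print(dims, val.dtype)  # [1] torch.float32
```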

backends/xnnpack/partition/configs.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -37,6 +37,9 @@

 SUPPORTED_MODULES = [
     torch.nn.Conv1d,
+    # TODO(T161981984) recomposed hardswish into a single node
+    torch.nn.Hardswish,
+    torch.nn.Hardsigmoid,
     torch.nn.Conv2d,
     torch.nn.ReLU,
     torch.nn.Sigmoid,
```
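The effect of adding the two modules to `SUPPORTED_MODULES` can be sketched with a simple type check (a hypothetical illustration, assuming a trimmed-down copy of the list; the real partitioner logic is more involved):

```python
import torch

# Trimmed-down copy of the updated SUPPORTED_MODULES list for illustration
SUPPORTED_MODULES = [
    torch.nn.Conv1d,
    torch.nn.Hardswish,    # newly supported by this commit
    torch.nn.Hardsigmoid,  # newly supported by this commit
    torch.nn.Conv2d,
    torch.nn.ReLU,
    torch.nn.Sigmoid,
]

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Hardswish())
partitionable = [type(m).__name__ for m in model if type(m) in SUPPORTED_MODULES]
print(partitionable)  # ['Conv2d', 'Hardswish']
```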
