Skip to content

Commit 636f9a7

Browse files
digantdesai authored and facebook-github-bot committed
Linear (#1901)
Summary: Pull Request resolved: #1901 Add fp16 Linear Reviewed By: kimishpatel Differential Revision: D53333693 fbshipit-source-id: c6c27049f2de1ef6b53ffae040282bcc2285357d
1 parent 6bdd250 commit 636f9a7

File tree

3 files changed

+36
-12
lines changed

3 files changed

+36
-12
lines changed

backends/xnnpack/operators/op_linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def define_node(
5959
xnn_graph,
6060
vals_to_ids,
6161
quant_params=weight_quant_params,
62+
fp32_static_weights=True,
6263
)
6364
filter_id = vals_to_ids[weight_node]
6465

@@ -73,6 +74,7 @@ def define_node(
7374
xnn_graph,
7475
vals_to_ids,
7576
quant_params=bias_quant_params,
77+
fp32_static_weights=True,
7678
)
7779
bias_id = vals_to_ids[bias_node]
7880
else:

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,23 +256,25 @@ Error defineTensor(
256256
/*flags=*/0, // this is netiher external input or output
257257
/*id_out=*/&id);
258258

259-
// this is the FP32 external value that is dynamically quantized
260-
uint32_t fp32_id;
259+
// this is the FP16 or FP32 external value that is being dynamically
260+
// quantized
261+
uint32_t float_id;
262+
enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
261263
status = xnn_define_tensor_value(
262264
/*subgraph=*/subgraph_ptr,
263-
/*datatype=*/xnn_datatype_fp32, // always fp32
265+
/*datatype=*/fp_datatype,
264266
/*num_dims=*/tensor_value->num_dims(),
265267
/*dims=*/dims_data.data(),
266268
/*data=*/buffer_ptr,
267269
/*external_id=*/tensor_value->external_id(),
268270
/*flags=*/tensor_value->flags(),
269-
/*id_out=*/&fp32_id);
270-
executor->addDynamicQinput(fp32_id);
271+
/*id_out=*/&float_id);
272+
executor->addDynamicQinput(float_id);
271273

272-
// Define dynamic conversion from fp32 to qdint8
274+
// Define dynamic conversion from float to qdint8
273275
status = xnn_define_convert(
274276
/*subgraph=*/subgraph_ptr,
275-
/*input_id=*/fp32_id,
277+
/*input_id=*/float_id,
276278
/*output_id=*/id,
277279
/*flags=*/0);
278280
break;

backends/xnnpack/test/ops/linear.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323

2424

2525
class TestLinear(unittest.TestCase):
26+
def test_fp16_linear(self):
27+
for use_bias in (True, False):
28+
self._test_linear(
29+
lambda in_size, out_size: torch.nn.Linear(
30+
in_size, out_size, bias=use_bias # noqa
31+
),
32+
uses_bias=use_bias,
33+
dtype=torch.float16,
34+
atol=5e-2,
35+
)
36+
2637
def test_fp32_linear(self):
2738
for use_bias in (True, False):
2839
self._test_linear(
@@ -284,7 +295,14 @@ def forward(self, x):
284295
quant=True,
285296
)
286297

287-
def _test_linear(self, make_module, uses_bias, quant=False):
298+
def _test_linear(
299+
self,
300+
make_module,
301+
uses_bias,
302+
quant=False,
303+
dtype: torch.dtype = torch.float,
304+
atol=1e-03,
305+
):
288306
aten_op, edge_op = (
289307
(
290308
"aten.addmm.default",
@@ -309,9 +327,10 @@ def _test_linear(self, make_module, uses_bias, quant=False):
309327
in_size = int(in_sizes[i])
310328
input_size = int(input_sizes[i])
311329
output_size = int(output_sizes[i])
330+
print(f"Testing {in_size} {input_size} {output_size}")
312331

313-
module = make_module(input_size, output_size).eval()
314-
inputs = (torch.randn(in_size, input_size),)
332+
module = make_module(input_size, output_size).eval().to(dtype)
333+
inputs = (torch.randn(in_size, input_size).to(dtype),)
315334

316335
tester = Tester(module, inputs)
317336

@@ -336,7 +355,8 @@ def _test_linear(self, make_module, uses_bias, quant=False):
336355
tester.to_executorch()
337356
tester.serialize()
338357
tester.run_method()
339-
tester.compare_outputs(qtol=quant)
358+
tester.compare_outputs(qtol=quant, atol=atol)
359+
print("success")
340360

341361
def _test_dqlinear(
342362
self,
@@ -370,7 +390,7 @@ def _test_dqlinear(
370390
tester.export()
371391
tester.check_count({aten_op: linear_count})
372392
tester.check(["torch.ops.quantized_decomposed"])
373-
393+
tester.dump_artifact()
374394
tester.to_edge()
375395
tester.check_count({edge_op: linear_count})
376396

0 commit comments

Comments (0)