
Commit 810253f

Update base for Update on "[executorch][flat_tensor] Serialize flat tensor"
Serialize a flat tensor file. The resulting file looks like:

Header
- flatbuffer offset and size
- segment data offset and size
Flatbuffer
Tensor data (in segment)

Differential Revision: [D66374253](https://our.internmc.facebook.com/intern/diff/D66374253/)

[ghstack-poisoned]
2 parents 4586c68 + 51a107a commit 810253f
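For readers unfamiliar with the layout described in the commit message, the sketch below shows one way such a file could be assembled in Python. It is illustrative only: the struct format, field order, and `FT01` magic are assumptions made for exposition, not the actual flat_tensor header defined by ExecuTorch.

```python
import struct

def serialize_flat_tensor(flatbuffer: bytes, tensor_segment: bytes) -> bytes:
    """Pack a header (offsets/sizes), then the flatbuffer, then the tensor segment."""
    header_fmt = "<4sQQQQ"  # 4-byte magic + four little-endian uint64 fields (assumed layout)
    header_size = struct.calcsize(header_fmt)
    flatbuffer_offset = header_size
    segment_offset = flatbuffer_offset + len(flatbuffer)
    header = struct.pack(
        header_fmt,
        b"FT01",              # hypothetical magic, not ExecuTorch's real value
        flatbuffer_offset,
        len(flatbuffer),
        segment_offset,
        len(tensor_segment),
    )
    return header + flatbuffer + tensor_segment
```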

9 files changed: +242, −30 lines changed

backends/cadence/fusion_g3/operators/op_mean.cpp

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ int prepare_data(
   return num_axis_dims;
 }
 
-Tensor& mean_dim_out(
+Tensor& mean_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
    optional<ArrayRef<int64_t>> dim_list,
@@ -199,4 +199,4 @@ Tensor& mean_dim_out(
 } // namespace native
 } // namespace G3
 } // namespace impl
-} // namespace cadence
+} // namespace cadence

extension/llm/modules/test/test_attention.py

Lines changed: 129 additions & 24 deletions
@@ -33,6 +33,7 @@ def setUp(self):
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
+        self.encoder_max_seq_len = 128
         self.rope_base = 500_000
         self.scale_factor = 32
 
@@ -86,16 +87,26 @@ def setUp(self):
             max_seq_len=self.max_seq_len,
         )
         self.et_mha.load_state_dict(self.tt_mha.state_dict())
+
         # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.y = torch.randn(1, seq_len, self.embed_dim)
         self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        self.dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
-        )
+        self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
+        self.dynamic_shapes = {
+            "x": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "y": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
+        }
         self.causal_mask = torch.tril(
             torch.ones(
                 size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
         assert_close(et_res, tt_res)
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
         # Self attention.
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
         # Self attention.
 
         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             so = torch._export.aot_compile(
                 self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):
 
     def test_attention_executorch(self):
         # Self attention.
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):
 
     def test_attention_torch_cond_eager(self):
         # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
-        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
         mask = self.causal_mask[self.input_pos, :]
         # First run.
-        et_res = self.et_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
 
         assert_close(et_res, tt_res)
 
-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
 
         empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
         tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
 
         assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_export(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+
+        # First run.
+        et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res, tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = et_mha_ep.module()(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_executorch(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+        et_program = to_edge(
+            et_mha_ep,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
+        # First run.
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.y, mask, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res[0], tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = method.execute((self.x, empty_y, mask, next_input_pos))
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res[0], tt_res)
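The main change in this test file is switching `dynamic_shapes` from a positional tuple to a dict keyed by argument name, with a single shared `seq_len` Dim reused across `x`, `y`, `input_pos`, and (in the torch.cond tests) `mask`. Below is a minimal sketch of that pattern on a toy module; the `Toy` module and tensor sizes are made up for illustration, and it assumes a recent PyTorch where `torch.export.Dim.STATIC` is available.

```python
import torch

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# One shared Dim ties the dynamic seq_len of both inputs together.
seq_len_dim = torch.export.Dim("seq_len", min=1, max=128)
dynamic_shapes = {
    "x": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
    "y": {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
}
ep = torch.export.export(
    Toy(),
    (torch.randn(1, 10, 64), torch.randn(1, 10, 64)),
    dynamic_shapes=dynamic_shapes,
    strict=True,
)
# The exported module now accepts any sequence length in [1, 128].
print(ep.module()(torch.randn(1, 32, 64), torch.randn(1, 32, 64)).shape)
```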

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -257,6 +257,8 @@
 
 - op: mean.out
 
+- op: mean.dtype_out
+
 - op: min.dim_min
 
 - op: min.unary_out

kernels/portable/cpu/op_mean.cpp

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,14 @@ Tensor& mean_dim_out(
   return out;
 }
 
+Tensor& mean_dtype_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    optional<ScalarType> dtype,
+    Tensor& out) {
+  return mean_dim_out(ctx, in, ArrayRef<int64_t>(), false, dtype, out);
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
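The new `mean_dtype_out` kernel simply forwards to `mean_dim_out` with an empty dim list and `keepdim=false`, i.e. a full reduction over every dimension. The snippet below is a sanity illustration of that equivalence using plain eager `torch.mean`, not the ExecuTorch kernel itself.

```python
import torch

x = torch.rand(10, 10)
# mean with dtype only == mean reduced over every dimension with keepdim=False.
full_reduction = torch.mean(x, dim=list(range(x.dim())), keepdim=False, dtype=torch.float64)
assert torch.allclose(torch.mean(x, dtype=torch.float64), full_reduction)
```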

kernels/portable/cpu/util/reduce_util.cpp

Lines changed: 1 addition & 0 deletions
@@ -386,6 +386,7 @@ bool check_mean_dim_args(
       check_reduction_args(in, dim_list, keepdim, dtype, out));
 
   if (dtype) {
+    ET_LOG(Info, "dtype is %hhd", static_cast<int8_t>(dtype.value()));
     ET_LOG_AND_RETURN_IF_FALSE(torch::executor::isFloatingType(dtype.value()));
     ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
   } else {

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -577,6 +577,11 @@
   - arg_meta: null
     kernel_name: torch::executor::mean_dim_out
 
+- op: mean.dtype_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::mean_dtype_out
+
 - op: min.dim_min
   kernels:
     - arg_meta: null

kernels/test/op_mean_test.cpp

Lines changed: 74 additions & 1 deletion
@@ -9,7 +9,7 @@
 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
 #include <executorch/kernels/test/TestUtil.h>
 #include <executorch/kernels/test/supported_features.h>
-#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
@@ -22,6 +22,7 @@ using exec_aten::ArrayRef;
 using exec_aten::optional;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
+using executorch::runtime::Error;
 using torch::executor::testing::TensorFactory;
 
 class OpMeanOutTest : public OperatorTest {
@@ -36,6 +37,13 @@ class OpMeanOutTest : public OperatorTest {
         context_, self, dim, keepdim, dtype, out);
   }
 
+  Tensor& op_mean_dtype_out(
+      const Tensor& self,
+      optional<ScalarType> dtype,
+      Tensor& out) {
+    return torch::executor::aten::mean_outf(context_, self, dtype, out);
+  }
+
   template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
   void test_mean_dim_out_invalid_dimensions() {
     TensorFactory<IN_DTYPE> tf_in;
@@ -466,3 +474,68 @@ TEST_F(OpMeanOutTest, DynamicShapeUnbound) {
   op_mean_out(x, ArrayRef<int64_t>{1}, false, ScalarType::Float, out);
   EXPECT_TENSOR_CLOSE(out, expected_result);
 }
+
+TEST_F(OpMeanOutTest, DTypeOutFloatValid) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make(
+      {10, 10},
+      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
+  Tensor expected_result = tf.make({}, {1.0});
+
+  Tensor out = tf.zeros({});
+  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}
+
+TEST_F(OpMeanOutTest, DTypeOutFloatToBoolInvalid) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make(
+      {10, 10},
+      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
+       1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
+  Tensor expected_result = tf.make({}, {1.0});
+
+  Tensor out = tf.zeros({});
+
+  ET_EXPECT_KERNEL_FAILURE(
+      context_, op_mean_dtype_out(x, ScalarType::Bool, out));
+}
+
+TEST_F(OpMeanOutTest, DTypeOutFloatInfinity) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({2, 1}, {INFINITY, INFINITY});
+  Tensor expected_result = tf.make({}, {INFINITY});
+
+  Tensor out = tf.zeros({});
+
+  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}
+
+TEST_F(OpMeanOutTest, DTypeOutFloatNAN) {
+  TensorFactory<ScalarType::Float> tf;
+
+  Tensor x = tf.make({2, 1}, {NAN, INFINITY});
+  Tensor expected_result = tf.make({}, {NAN});
+
+  Tensor out = tf.zeros({});
+
+  Tensor ret = op_mean_dtype_out(x, ScalarType::Float, out);
+  EXPECT_TENSOR_CLOSE(out, expected_result);
+}