Skip to content

Commit 4aab20f

Browse files
committed
Add making JsTensor from iterator
1 parent 7de34b9 commit 4aab20f

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

extension/wasm/test/unittests.js

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ describe("Tensor", () => {
3636
tensor.delete();
3737
});
3838

39+
test("fromGenerator", () => {
40+
function* generator() {
41+
yield* [1, 2, 3, 4];
42+
}
43+
const tensor = et.Tensor.fromIter([2, 2], generator());
44+
expect(tensor.data).toEqual(new Float32Array([1, 2, 3, 4]));
45+
expect(tensor.sizes).toEqual([2, 2]);
46+
tensor.delete();
47+
});
48+
3949
test("fromArray wrong size", () => {
4050
expect(() => et.Tensor.fromArray([3, 2], [1, 2, 3, 4])).toThrow();
4151
});

extension/wasm/wasm_bindings.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,19 @@ inline void assert_valid_numel(
8383
data.size());
8484
}
8585

86+
template <typename T>
87+
std::vector<T> convertJSGeneratorToNumberVector(val generator) {
88+
std::vector<T> data;
89+
while (true) {
90+
val next = generator.call<val>("next");
91+
if (next["done"].as<bool>()) {
92+
break;
93+
}
94+
data.push_back(next["value"].as<T>());
95+
}
96+
return data;
97+
}
98+
8699
class JsTensor {
87100
public:
88101
JsTensor() = delete;
@@ -204,6 +217,49 @@ class JsTensor {
204217
}
205218
}
206219

220+
static std::unique_ptr<JsTensor> from_iter(
221+
val_array<int> sizes,
222+
val_array<val> data,
223+
val type = val::undefined(),
224+
val_array<int> dim_order = val::undefined(),
225+
val_array<int> strides = val::undefined()) {
226+
auto sizes_vec =
227+
convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
228+
229+
auto dim_order_vec = dim_order.isUndefined()
230+
? std::vector<executorch::aten::DimOrderType>()
231+
: convertJSArrayToNumberVector<executorch::aten::DimOrderType>(
232+
dim_order);
233+
auto strides_vec = strides.isUndefined()
234+
? std::vector<executorch::aten::StridesType>()
235+
: convertJSArrayToNumberVector<executorch::aten::StridesType>(strides);
236+
237+
// If type is undefined, infer the type from the data.
238+
// Assume it is a Bigint if not Number.
239+
ScalarType scalar_type = type.isUndefined()
240+
? (data["length"].as<size_t>() == 0 || data[0].isNumber()
241+
? ScalarType::Float
242+
: ScalarType::Long)
243+
: type.as<ScalarType>();
244+
switch (scalar_type) {
245+
#define JS_CASE_FROM_ITER_VECTOR_TYPE(T, NAME) \
246+
case ScalarType::NAME: { \
247+
auto data_vec = convertJSGeneratorToNumberVector<T>(data); \
248+
assert_valid_numel(data_vec, sizes_vec); \
249+
TensorPtr tensor = make_tensor_ptr( \
250+
std::move(sizes_vec), \
251+
std::move(data_vec), \
252+
std::move(dim_order_vec), \
253+
std::move(strides_vec), \
254+
ScalarType::NAME); \
255+
return std::make_unique<JsTensor>(std::move(tensor)); \
256+
}
257+
JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_FROM_ITER_VECTOR_TYPE)
258+
default:
259+
THROW_JS_ERROR(TypeError, "Unsupported Tensor type: %d", scalar_type);
260+
}
261+
}
262+
207263
private:
208264
TensorPtr tensor_;
209265
};
@@ -448,6 +504,7 @@ EMSCRIPTEN_BINDINGS(WasmBindings) {
448504
.class_function("ones", &JsTensor::ones)
449505
.class_function("full", &JsTensor::full)
450506
.class_function("fromArray", &JsTensor::from_array)
507+
.class_function("fromIter", &JsTensor::from_iter)
451508
.property("scalarType", &JsTensor::get_scalar_type)
452509
.property("data", &JsTensor::get_data)
453510
.property("sizes", &JsTensor::get_sizes);

0 commit comments

Comments
 (0)