@@ -83,6 +83,19 @@ inline void assert_valid_numel(
8383 data.size ());
8484}
8585
86+ template <typename T>
87+ std::vector<T> convertJSGeneratorToNumberVector (val generator) {
88+ std::vector<T> data;
89+ while (true ) {
90+ val next = generator.call <val>(" next" );
91+ if (next[" done" ].as <bool >()) {
92+ break ;
93+ }
94+ data.push_back (next[" value" ].as <T>());
95+ }
96+ return data;
97+ }
98+
8699class JsTensor {
87100 public:
88101 JsTensor () = delete ;
@@ -204,6 +217,49 @@ class JsTensor {
204217 }
205218 }
206219
220+ static std::unique_ptr<JsTensor> from_iter (
221+ val_array<int > sizes,
222+ val_array<val> data,
223+ val type = val::undefined(),
224+ val_array<int> dim_order = val::undefined(),
225+ val_array<int> strides = val::undefined()) {
226+ auto sizes_vec =
227+ convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
228+
229+ auto dim_order_vec = dim_order.isUndefined ()
230+ ? std::vector<executorch::aten::DimOrderType>()
231+ : convertJSArrayToNumberVector<executorch::aten::DimOrderType>(
232+ dim_order);
233+ auto strides_vec = strides.isUndefined ()
234+ ? std::vector<executorch::aten::StridesType>()
235+ : convertJSArrayToNumberVector<executorch::aten::StridesType>(strides);
236+
237+ // If type is undefined, infer the type from the data.
238+ // Assume it is a Bigint if not Number.
239+ ScalarType scalar_type = type.isUndefined ()
240+ ? (data[" length" ].as <size_t >() == 0 || data[0 ].isNumber ()
241+ ? ScalarType::Float
242+ : ScalarType::Long)
243+ : type.as <ScalarType>();
244+ switch (scalar_type) {
245+ #define JS_CASE_FROM_ITER_VECTOR_TYPE (T, NAME ) \
246+ case ScalarType::NAME: { \
247+ auto data_vec = convertJSGeneratorToNumberVector<T>(data); \
248+ assert_valid_numel (data_vec, sizes_vec); \
249+ TensorPtr tensor = make_tensor_ptr ( \
250+ std::move (sizes_vec), \
251+ std::move (data_vec), \
252+ std::move (dim_order_vec), \
253+ std::move (strides_vec), \
254+ ScalarType::NAME); \
255+ return std::make_unique<JsTensor>(std::move (tensor)); \
256+ }
257+ JS_FORALL_SUPPORTED_TENSOR_TYPES (JS_CASE_FROM_ITER_VECTOR_TYPE)
258+ default :
259+ THROW_JS_ERROR (TypeError, " Unsupported Tensor type: %d" , scalar_type);
260+ }
261+ }
262+
207263 private:
208264 TensorPtr tensor_;
209265};
@@ -448,6 +504,7 @@ EMSCRIPTEN_BINDINGS(WasmBindings) {
448504 .class_function (" ones" , &JsTensor::ones)
449505 .class_function (" full" , &JsTensor::full)
450506 .class_function (" fromArray" , &JsTensor::from_array)
507+ .class_function (" fromIter" , &JsTensor::from_iter)
451508 .property (" scalarType" , &JsTensor::get_scalar_type)
452509 .property (" data" , &JsTensor::get_data)
453510 .property (" sizes" , &JsTensor::get_sizes);
0 commit comments