Skip to content

Commit 4805551

Browse files
committed
feat: Add ArrowDataFrame.from_numpy
1 parent b333351 commit 4805551

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,14 @@
5353
from narwhals._arrow.typing import Mask # type: ignore[attr-defined]
5454
from narwhals._arrow.typing import Order # type: ignore[attr-defined]
5555
from narwhals.dtypes import DType
56+
from narwhals.schema import Schema
5657
from narwhals.typing import CompliantDataFrame
5758
from narwhals.typing import CompliantLazyFrame
5859
from narwhals.typing import SizeUnit
5960
from narwhals.typing import _1DArray
6061
from narwhals.typing import _2DArray
6162
from narwhals.utils import Version
63+
from narwhals.utils import _FullContext
6264

6365
JoinType: TypeAlias = Literal[
6466
"left semi",
@@ -91,6 +93,40 @@ def __init__(
9193
self._version = version
9294
validate_backend_version(self._implementation, self._backend_version)
9395

96+
@classmethod
97+
def from_numpy(
98+
cls,
99+
data: _2DArray,
100+
/,
101+
*,
102+
context: _FullContext,
103+
schema: Mapping[str, DType] | Schema | Sequence[str] | None,
104+
) -> Self:
105+
from narwhals.schema import Schema
106+
107+
arrays = [pa.array(val) for val in data.T]
108+
if isinstance(schema, (Mapping, Schema)):
109+
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
110+
elif is_sequence_but_not_str(schema):
111+
native = pa.Table.from_arrays(arrays, names=list(schema))
112+
elif schema is None:
113+
native = pa.Table.from_arrays(
114+
arrays, names=[f"column_{x}" for x in range(data.shape[1])]
115+
)
116+
else:
117+
msg = (
118+
"`schema` is expected to be one of the following types: "
119+
"Mapping[str, DType] | Schema | Sequence[str]. "
120+
f"Got {type(schema)}."
121+
)
122+
raise TypeError(msg)
123+
return cls(
124+
native,
125+
backend_version=context._backend_version,
126+
version=context._version,
127+
validate_column_names=True,
128+
)
129+
94130
def __narwhals_namespace__(self: Self) -> ArrowNamespace:
95131
from narwhals._arrow.namespace import ArrowNamespace
96132

0 commit comments

Comments
 (0)