Skip to content

Commit a311e9c

Browse files
issue/591 infinicore.narrow
1 parent 2286cf7 commit a311e9c

File tree

5 files changed

+108
-6
lines changed

5 files changed

+108
-6
lines changed

python/infinicore/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from infinicore.ops.attention import attention
3232
from infinicore.ops.matmul import matmul
3333
from infinicore.ops.mul import mul
34+
from infinicore.ops.narrow import narrow
3435
from infinicore.ops.rearrange import rearrange
3536
from infinicore.tensor import (
3637
Tensor,
@@ -78,6 +79,7 @@
7879
"attention",
7980
"matmul",
8081
"mul",
82+
"narrow",
8183
"rearrange",
8284
"empty",
8385
"empty_like",

python/infinicore/ops/narrow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from infinicore.tensor import Tensor
2+
3+
4+
def narrow(input: Tensor, dim: int, start: int, length: int) -> Tensor:
5+
return Tensor(input._underlying.narrow(dim, start, length))

python/infinicore/tensor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def numel(self):
5252
def is_contiguous(self):
5353
return self._underlying.is_contiguous()
5454

55-
def is_is_pinned(self):
56-
return self._underlying.is_is_pinned()
55+
def is_pinned(self):
56+
return self._underlying.is_pinned()
5757

5858
def copy_(self, src):
5959
self._underlying.copy_(src._underlying)
@@ -63,12 +63,12 @@ def to(self, *args, **kwargs):
6363
self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs)
6464
)
6565

66-
def as_strided(self, size, stride):
67-
return Tensor(self._underlying.as_strided(size, stride))
68-
6966
def contiguous(self):
7067
return Tensor(self._underlying.contiguous())
7168

69+
def as_strided(self, size, stride):
70+
return Tensor(self._underlying.as_strided(size, stride))
71+
7272
def permute(self, dims):
7373
return Tensor(self._underlying.permute(dims))
7474

src/infinicore/pybind11/tensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ inline void bind(py::module &m) {
3232
.def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); })
3333
.def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
3434
.def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); })
35-
35+
.def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); })
3636
.def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); })
3737
.def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); });
3838

test/infinicore/tensor/narrow.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import sys
2+
import os
3+
4+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
5+
6+
import torch
7+
import infinicore
8+
from framework.base import BaseOperatorTest, TensorSpec, TestCase
9+
from framework.runner import GenericTestRunner
10+
from framework.utils import is_broadcast
11+
12+
# ==============================================================================
13+
# Operator-specific configuration
14+
# ==============================================================================
15+
16+
# Test cases format: (shape, dim, start, length)
17+
_TEST_CASES_DATA = [
18+
# Basic cases
19+
((2, 4), 0, 0, 1),
20+
((2, 4), 1, 1, 1),
21+
((5, 3, 2), 1, 0, 3),
22+
((5, 3, 2), 0, 1, 3),
23+
((4, 4, 1024, 32), 2, 1023, 1),
24+
]
25+
26+
# Tolerance configuration
27+
_TOLERANCE_MAP = {
28+
infinicore.float16: {"atol": 0, "rtol": 0},
29+
infinicore.float32: {"atol": 0, "rtol": 0},
30+
infinicore.bfloat16: {"atol": 0, "rtol": 0},
31+
}
32+
33+
# Data types to test
34+
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
35+
36+
37+
def parse_test_cases():
38+
"""
39+
Parse test case data and return list of TestCase objects for all operation types.
40+
Each test case contains all necessary information for execution and validation.
41+
"""
42+
test_cases = []
43+
44+
for data in _TEST_CASES_DATA:
45+
shape = data[0]
46+
dim = data[1]
47+
start = data[2]
48+
length = data[3]
49+
50+
# Generate test cases for all data types
51+
for dtype in _TENSOR_DTYPES:
52+
tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
53+
54+
# Create typed tensor specs
55+
a_spec = TensorSpec.from_tensor(shape, None, dtype)
56+
test_cases.append(
57+
TestCase(
58+
inputs=[a_spec, dim, start, length],
59+
kwargs={},
60+
output_spec=None,
61+
comparison_target=None, # Compare output
62+
tolerance=tolerance,
63+
description=f"Narrow",
64+
)
65+
)
66+
67+
return test_cases
68+
69+
70+
class OpTest(BaseOperatorTest):
71+
"""Narrow operator test with simplified implementation"""
72+
73+
def __init__(self):
74+
super().__init__("Narrow")
75+
76+
def get_test_cases(self):
77+
return parse_test_cases()
78+
79+
def torch_operator(self, *args, **kwargs):
80+
"""PyTorch narrow implementation"""
81+
return torch.narrow(*args, **kwargs)
82+
83+
def infinicore_operator(self, *args, **kwargs):
84+
"""InfiniCore narrow implementation"""
85+
return infinicore.narrow(*args, **kwargs)
86+
87+
88+
def main():
89+
"""Main entry point"""
90+
runner = GenericTestRunner(OpTest)
91+
runner.run_and_exit()
92+
93+
94+
if __name__ == "__main__":
95+
main()

0 commit comments

Comments
 (0)