Skip to content

Commit 9a4fb9b

Browse files
committed
Merge from main && resolve conflict && format code
2 parents 32bd2f8 + 0ead67f commit 9a4fb9b

File tree

18 files changed

+426
-47
lines changed

18 files changed

+426
-47
lines changed

include/infinicore/tensor.hpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,19 @@ class TensorImpl : public std::enable_shared_from_this<TensorImpl> {
168168
/// View APIs
169169
///
170170

171+
/**
172+
* Returns a new tensor with a dimension of size one removed at the specified position.
173+
* Throws runtime_error if the dimension to be removed is not of size 1.
174+
*
175+
* @param dim The dimension index to remove
176+
* @return A new tensor with the specified dimension removed
177+
*
178+
* Example:
179+
* // For a 3D tensor with shape [1, 3, 4], squeeze at dim 0 results in shape [3, 4]
180+
* tensor->squeeze(0);
181+
*/
182+
Tensor squeeze(size_t dim) const;
183+
171184
/**
172185
* Returns a new tensor with a dimension of size one inserted at the specified position.
173186
* The returned tensor shares the same underlying storage with the original tensor.

python/infinicore/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from infinicore.ops.mul import mul
4646
from infinicore.ops.narrow import narrow
4747
from infinicore.ops.rearrange import rearrange
48+
from infinicore.ops.squeeze import squeeze
49+
from infinicore.ops.unsqueeze import unsqueeze
4850
from infinicore.tensor import (
4951
Tensor,
5052
empty,
@@ -104,6 +106,8 @@
104106
"matmul",
105107
"mul",
106108
"narrow",
109+
"squeeze",
110+
"unsqueeze",
107111
"rearrange",
108112
"empty",
109113
"empty_like",

python/infinicore/nn/functional/rope.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@ def rope(
2020
) -> Tensor:
2121
r"""Rotary Position Embedding(RoPE)."""
2222

23-
bs, seq_len, num_heads, head_dim = x.shape
24-
x_stride = x.stride()
25-
assert seq_len * x_stride[1] == x_stride[0], (
26-
"x need to be continuous in dim=0 and dim=1"
27-
)
28-
29-
x = x.view((bs * seq_len, num_heads, head_dim))
30-
bs, num = pos_ids.shape
31-
pos_ids = pos_ids.view((bs * num,))
32-
3323
if out is None:
3424
return Tensor(
3525
_infinicore.rope(
@@ -39,9 +29,8 @@ def rope(
3929
cos_table._underlying,
4030
algo,
4131
)
42-
).view((bs, seq_len, num_heads, head_dim))
32+
)
4333

44-
out = out.view((bs * seq_len, num_heads, head_dim))
4534
_infinicore.rope_(
4635
out._underlying,
4736
x._underlying,
@@ -50,4 +39,4 @@ def rope(
5039
cos_table._underlying,
5140
algo,
5241
)
53-
return out.view((bs, seq_len, num_heads, head_dim))
42+
return out

python/infinicore/ops/squeeze.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from infinicore.tensor import Tensor


def squeeze(input: Tensor, dim: int) -> Tensor:
    """Return a new tensor with the size-1 dimension at ``dim`` removed.

    Thin wrapper delegating to the underlying C++ tensor's ``squeeze``;
    the result shares storage with ``input`` (a view, per the C++ impl).
    The native layer raises if ``dim`` is not a size-1 dimension.
    """
    return Tensor(input._underlying.squeeze(dim))

python/infinicore/ops/unsqueeze.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from infinicore.tensor import Tensor


def unsqueeze(input: Tensor, dim: int) -> Tensor:
    """Return a new tensor with a dimension of size one inserted at ``dim``.

    Thin wrapper delegating to the underlying C++ tensor's ``unsqueeze``;
    the result shares storage with ``input`` (a view of the same data).
    """
    return Tensor(input._underlying.unsqueeze(dim))

python/infinicore/tensor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def permute(self, dims):
9292
def view(self, shape):
9393
return Tensor(self._underlying.view(shape))
9494

95+
def squeeze(self, dim):
    """Remove the size-1 dimension at ``dim``; delegates to ``infinicore.squeeze``."""
    return infinicore.squeeze(self, dim)
97+
98+
def unsqueeze(self, dim):
    """Insert a size-1 dimension at ``dim``; delegates to ``infinicore.unsqueeze``."""
    return infinicore.unsqueeze(self, dim)
100+
95101
def debug(self, filename=None):
96102
"""Print tensor data or save to file for debugging
97103

src/infinicore/pybind11/tensor.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,27 @@ inline void bind(py::module &m) {
1616
.def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); })
1717
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
1818
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
19-
2019
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); })
2120
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
2221
.def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); })
2322
.def("numel", [](const Tensor &tensor) { return tensor->numel(); })
24-
2523
.def("is_contiguous", [](const Tensor &tensor) { return tensor->is_contiguous(); })
2624
.def("is_pinned", [](const Tensor &tensor) { return tensor->is_pinned(); })
2725
.def("info", [](const Tensor &tensor) { return tensor->info(); })
26+
2827
.def("debug", [](const Tensor &tensor) { return tensor->debug(); })
2928
.def("debug", [](const Tensor &tensor, const std::string &filename) { return tensor->debug(filename); })
3029

3130
.def("copy_", [](Tensor &tensor, const Tensor &other) { tensor->copy_from(other); })
3231
.def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); })
33-
.def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
3432
.def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); })
33+
34+
.def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
3535
.def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); })
3636
.def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); })
37-
.def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); });
37+
.def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); })
38+
.def("unsqueeze", [](const Tensor &tensor, std::size_t dim) { return tensor->unsqueeze(dim); })
39+
.def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); });
3840

3941
m.def("empty", &Tensor::empty,
4042
py::arg("shape"),

src/infinicore/tensor/view.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,23 @@
66
#include <stdexcept>
77

88
namespace infinicore {
9+
/// Returns a view of this tensor with the size-1 dimension at `dim` removed.
/// The returned tensor shares the underlying storage with this tensor.
///
/// @param dim Index of the dimension to remove; must be in range and of size 1.
/// @return A new tensor whose shape/strides omit dimension `dim`.
/// @throws std::runtime_error if `dim` is out of range or not of size 1.
Tensor TensorImpl::squeeze(size_t dim) const {
    // Validate the index first: meta_.shape[dim] on an out-of-range dim is
    // undefined behavior, so this check must precede the size-1 check.
    if (dim >= meta_.shape.size()) {
        spdlog::error("Dimension {} is out of range for squeeze operation on {}.", dim, this->info());
        throw std::runtime_error("Invalid squeeze operation on tensor.");
    }
    // Only a dimension of extent 1 may be squeezed away.
    if (meta_.shape[dim] != 1) {
        spdlog::error("Dimension {} is not of size 1 for squeeze operation on {}.", dim, this->info());
        throw std::runtime_error("Invalid squeeze operation on tensor.");
    }

    // Drop the dimension from both shape and strides (a size-1 dim carries no
    // addressing information, so no data movement is needed).
    Shape new_shape = meta_.shape;
    new_shape.erase(new_shape.begin() + static_cast<std::ptrdiff_t>(dim));
    Strides new_strides = meta_.strides;
    new_strides.erase(new_strides.begin() + static_cast<std::ptrdiff_t>(dim));

    // Build the view: new metadata, same storage as this tensor (no copy).
    auto tensor_impl = std::make_shared<TensorImpl>(new_shape, new_strides, meta_.dtype);
    tensor_impl->data_ = data_;

    return Tensor(tensor_impl);
}
25+
926
Tensor TensorImpl::unsqueeze(size_t dim) const {
1027
// Create new shape with dimension of size one inserted at dim
1128
Shape new_shape = meta_.shape;

src/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef INFINIUTILS_H
22
#define INFINIUTILS_H
33

4+
#include "infinicore.h"
45
#include "utils/custom_types.h"
56
#include "utils/rearrange.h"
67

src/utils/check.h

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,19 @@
33
#include <iostream>
44
#include <tuple>
55

6+
#include "../utils.h"
67
#include "infini_status_string.h"
78

9+
// Generic guard macro: when CONDITION evaluates false, logs the failed
// expression together with the enclosing function, file, and line to stderr,
// then executes ACTION (typically a `return ...;` in the calling function).
// Wrapped in do/while(0) so it expands to a single statement.
#define CHECK_OR_DO(CONDITION, ACTION)                                     \
    do {                                                                   \
        if (!(CONDITION)) {                                                \
            std::cerr << "Check Failed: `(" << #CONDITION << ")` is False" \
                      << " from " << __func__                              \
                      << " at " << __FILE__ << ":" << __LINE__             \
                      << std::endl;                                        \
            { ACTION; }                                                    \
        }                                                                  \
    } while (0)
18+
819
#define CHECK_OR_RETURN(CONDITION, ERROR) \
920
do { \
1021
if (!(CONDITION)) { \
@@ -33,17 +44,19 @@
3344
std::cerr << "Error: " << infini_status_string(api_result_) << std::endl; \
3445
return api_result_)
3546

36-
#define CHECK_DTYPE(DT, ...) \
37-
do { \
38-
auto found_supported_dtype = false; \
39-
for (auto dt : {__VA_ARGS__}) { \
40-
if (dt == DT) { \
41-
found_supported_dtype = true; \
42-
break; \
43-
} \
44-
} \
45-
CHECK_API_OR(found_supported_dtype, true, \
46-
return INFINI_STATUS_BAD_TENSOR_DTYPE); \
47+
// Validates that dtype DT is one of the dtypes listed in __VA_ARGS__.
// On failure, prints the unsupported dtype's name to stderr and makes the
// enclosing function return INFINI_STATUS_BAD_TENSOR_DTYPE (so this macro
// may only be used inside functions returning an infini status code).
#define CHECK_DTYPE(DT, ...)                                          \
    do {                                                              \
        /* linear scan over the allowed dtype list */                 \
        auto dtype_is_supported = false;                              \
        for (auto dt : {__VA_ARGS__}) {                               \
            if (dt == DT) {                                           \
                dtype_is_supported = true;                            \
                break;                                                \
            }                                                         \
        }                                                             \
        CHECK_OR_DO(dtype_is_supported,                               \
                    { std::cerr << "Unsupported dtype: "              \
                                << infiniDtypeToString(DT) << ". ";   \
                      return INFINI_STATUS_BAD_TENSOR_DTYPE; });      \
    } while (0)
4861

4962
#define CHECK_DTYPE_ANY_INT(DT) \

0 commit comments

Comments
 (0)