Skip to content

Commit ffdb0c6

Browse files
Improve section array map (#437)
* Refactor section array map Improve const handling for the section array map * Update dependencies * Improve module constructor * Fix type hint * Improve stub generator
1 parent 63938db commit ffdb0c6

File tree

10 files changed

+89
-39
lines changed

10 files changed

+89
-39
lines changed

requirements.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44

55
AMULET_COMPILER_TARGET_REQUIREMENT = "==2.0"
66

7-
PYBIND11_REQUIREMENT = "==3.0.0"
8-
AMULET_PYBIND11_EXTENSIONS_REQUIREMENT = "~=1.2.0.0a0"
7+
PYBIND11_REQUIREMENT = "==3.0.1"
8+
AMULET_PYBIND11_EXTENSIONS_REQUIREMENT = "~=1.2.0.0a2"
99
AMULET_IO_REQUIREMENT = "~=1.0"
10-
AMULET_UTILS_REQUIREMENT = "~=1.1.3.0a1"
10+
AMULET_UTILS_REQUIREMENT = "~=1.1.3.0a6"
1111
AMULET_ZLIB_REQUIREMENT = "~=1.0.8.0a0"
12-
AMULET_NBT_REQUIREMENT = "~=5.0.2.0a0"
12+
AMULET_NBT_REQUIREMENT = "~=5.0.2.0a2"
1313
NUMPY_REQUIREMENT = "~=2.0"
1414

1515
if os.environ.get("AMULET_PYBIND11_EXTENSIONS_REQUIREMENT", None):

src/amulet/core/_amulet_core.py.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,5 @@ void init_module(py::module m)
3434

3535
PYBIND11_MODULE(_amulet_core, m)
3636
{
37-
py::options options;
38-
options.disable_function_signatures();
39-
m.def("init", &init_module, py::doc("init(arg0: types.ModuleType) -> None"));
40-
options.enable_function_signatures();
37+
m.def("init", &init_module, py::arg("m"));
4138
}

src/amulet/core/_amulet_core.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ import types
44

55
__all__: list[str] = ["init"]
66

7-
def init(arg0: types.ModuleType) -> None: ...
7+
def init(m: types.ModuleType) -> None: ...

src/amulet/core/chunk/component/section_array_map.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,8 @@ void SectionArrayMap::serialise(BinaryWriter& writer) const
153153
get_default_array());
154154

155155
// Write arrays
156-
const auto& arrays = get_arrays();
157-
writer.write_numeric<std::uint64_t>(arrays.size());
158-
for (const auto& [cy, arr] : arrays) {
156+
writer.write_numeric<std::uint64_t>(_arrays.size());
157+
for (const auto& [cy, arr] : _arrays) {
159158
writer.write_numeric<std::int64_t>(cy);
160159
arr->serialise(writer);
161160
}

src/amulet/core/chunk/component/section_array_map.hpp

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include <amulet/io/binary_reader.hpp>
1515
#include <amulet/io/binary_writer.hpp>
1616

17+
#include <amulet/utils/view/map_view.hpp>
18+
1719
#include <amulet/core/dll.hpp>
1820

1921
namespace Amulet {
@@ -41,9 +43,10 @@ class IndexArray3D {
4143
AMULET_CORE_EXPORT static IndexArray3D deserialise(BinaryReader&);
4244

4345
const SectionShape& get_shape() const { return _shape; }
44-
const size_t& get_size() const { return _size; }
45-
std::uint32_t* get_buffer() const { return _buffer; }
46-
std::span<std::uint32_t> get_span() const { return { _buffer, _size }; }
46+
size_t get_size() const { return _size; }
47+
std::uint32_t* get_buffer() { return _buffer; }
48+
const std::uint32_t* get_buffer() const { return _buffer; }
49+
std::span<std::uint32_t> get_span() { return { _buffer, _size }; }
4750
};
4851

4952
class SectionArrayMap {
@@ -84,18 +87,43 @@ class SectionArrayMap {
8487

8588
const SectionShape& get_array_shape() const { return _array_shape; }
8689

87-
std::variant<std::uint32_t, std::shared_ptr<IndexArray3D>> get_default_array() const
90+
std::variant<std::uint32_t, std::shared_ptr<const IndexArray3D>> get_default_array() const
91+
{
92+
return std::visit(
93+
[](auto&& arg) -> std::variant<std::uint32_t, std::shared_ptr<const IndexArray3D>> {
94+
return arg;
95+
},
96+
_default_array);
97+
}
98+
99+
std::variant<std::uint32_t, std::shared_ptr<IndexArray3D>> get_default_array()
88100
{
89101
return _default_array;
90102
}
91103

92-
void set_default_array(std::variant<std::uint32_t, std::shared_ptr<IndexArray3D>> default_array)
104+
void set_default_array(std::uint32_t default_array)
93105
{
94-
validate_array_shape(default_array);
106+
_default_array = default_array;
107+
}
108+
109+
void set_default_array(std::shared_ptr<IndexArray3D> default_array)
110+
{
111+
validate_array_shape(*default_array);
95112
_default_array = std::move(default_array);
96113
}
97114

98-
const std::unordered_map<std::int64_t, std::shared_ptr<IndexArray3D>>& get_arrays() const
115+
void set_default_array(const IndexArray3D& default_array)
116+
{
117+
validate_array_shape(default_array);
118+
_default_array = std::make_shared<IndexArray3D>(default_array);
119+
}
120+
121+
const std::unordered_map<std::int64_t, std::shared_ptr<IndexArray3D>>& get_arrays()
122+
{
123+
return _arrays;
124+
}
125+
126+
const MapView<std::unordered_map<std::int64_t, std::shared_ptr<IndexArray3D>>, std::shared_ptr<const IndexArray3D>> get_arrays() const
99127
{
100128
return _arrays;
101129
}
@@ -107,17 +135,38 @@ class SectionArrayMap {
107135
return _arrays.contains(cy);
108136
}
109137

110-
std::shared_ptr<IndexArray3D> get_section(std::int64_t cy) const
138+
std::shared_ptr<IndexArray3D> get_section(std::int64_t cy)
111139
{
112140
return _arrays.at(cy);
113141
}
114142

143+
std::shared_ptr<const IndexArray3D> get_section(std::int64_t cy) const
144+
{
145+
return _arrays.at(cy);
146+
}
147+
148+
IndexArray3D& get_section_ref(std::int64_t cy)
149+
{
150+
return *_arrays.at(cy);
151+
}
152+
153+
const IndexArray3D& get_section_ref(std::int64_t cy) const
154+
{
155+
return *_arrays.at(cy);
156+
}
157+
115158
void set_section(std::int64_t cy, std::shared_ptr<IndexArray3D> section)
116159
{
117160
validate_array_shape(*section);
118161
_arrays.insert_or_assign(cy, std::move(section));
119162
}
120163

164+
void set_section(std::int64_t cy, const IndexArray3D& section)
165+
{
166+
validate_array_shape(section);
167+
_arrays.insert_or_assign(cy, std::make_shared<IndexArray3D>(section));
168+
}
169+
121170
AMULET_CORE_EXPORT void populate_section(std::int64_t cy);
122171

123172
void del_section(std::int64_t cy)

src/amulet/core/chunk/component/section_array_map.py.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ py::module init_section_array_map(py::module m_parent)
144144
py::object asarray = py::module::import("numpy").attr("asarray");
145145
SectionArrayMap.def_property(
146146
"default_array",
147-
[asarray](const Amulet::SectionArrayMap& self) {
147+
[asarray](Amulet::SectionArrayMap& self) {
148148
return std::visit([asarray](auto&& arg) -> std::variant<std::uint32_t, py::array> {
149149
using T = std::decay_t<decltype(arg)>;
150150
if constexpr (std::is_same_v<T, std::uint32_t>) {
@@ -202,7 +202,7 @@ py::module init_section_array_map(py::module m_parent)
202202
&Amulet::SectionArrayMap::del_section);
203203
SectionArrayMap.def(
204204
"__getitem__",
205-
[asarray](const Amulet::SectionArrayMap& self, std::int64_t cy) -> pyext::PyObjectCpp<pyext::numpy::array_t<std::uint32_t>> {
205+
[asarray](Amulet::SectionArrayMap& self, std::int64_t cy) -> pyext::PyObjectCpp<pyext::numpy::array_t<std::uint32_t>> {
206206
try {
207207
return asarray(py::cast(self.get_section(cy)));
208208
} catch (const std::out_of_range&) {
@@ -214,7 +214,7 @@ py::module init_section_array_map(py::module m_parent)
214214
&Amulet::SectionArrayMap::get_size);
215215
SectionArrayMap.def(
216216
"__iter__",
217-
[](const Amulet::SectionArrayMap& self) -> pyext::collections::Iterator<std::int64_t> {
217+
[](Amulet::SectionArrayMap& self) -> pyext::collections::Iterator<std::int64_t> {
218218
return pyext::make_iterator(pyext::detail::MapIterator(self.get_arrays()));
219219
},
220220
py::keep_alive<0, 1>());

src/amulet/core/chunk/component/section_array_map.pyi

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,8 @@ class SectionArrayMap:
9494
def update(
9595
self,
9696
other: (
97-
collections.abc.Mapping[
98-
typing.SupportsInt, numpy.typing.NDArray[numpy.uint32]
99-
]
100-
| collections.abc.Iterable[
101-
tuple[typing.SupportsInt, numpy.typing.NDArray[numpy.uint32]]
102-
]
97+
collections.abc.Mapping[int, numpy.typing.NDArray[numpy.uint32]]
98+
| collections.abc.Iterable[tuple[int, numpy.typing.NDArray[numpy.uint32]]]
10399
) = (),
104100
**kwargs: numpy.typing.NDArray[numpy.uint32],
105101
) -> None: ...

tests/test_amulet_core/_test_amulet_core.py.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,5 @@ void init_module(py::module m){
3131
}
3232

3333
PYBIND11_MODULE(_test_amulet_core, m) {
34-
py::options options;
35-
options.disable_function_signatures();
36-
m.def("init", &init_module, py::doc("init(arg0: types.ModuleType) -> None"));
37-
options.enable_function_signatures();
34+
m.def("init", &init_module, py::arg("m"));
3835
}

tests/test_amulet_core/_test_amulet_core.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ import types
44

55
__all__: list[str] = ["init"]
66

7-
def init(arg0: types.ModuleType) -> None: ...
7+
def init(m: types.ModuleType) -> None: ...

tools/generate_pybind_stubs.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,28 @@
77
import pybind11_stubgen
88
from pybind11_stubgen.structs import Identifier
99
from pybind11_stubgen.parser.mixins.filter import FilterClassMembers
10-
from pybind11_stubgen import main as pybind11_stubgen_main
10+
11+
12+
ForwardRefPattern = re.compile(r"ForwardRef\('(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)'\)")
13+
14+
QuotePattern = re.compile(r"'(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*)'")
15+
16+
17+
def fix_value(value: str) -> str:
18+
value = value.replace("NoneType", "None")
19+
value = ForwardRefPattern.sub(lambda match: match.group("variable"), value)
20+
value = QuotePattern.sub(lambda match: match.group("variable"), value)
21+
return value
22+
1123

1224
UnionPattern = re.compile(
13-
r"^(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*): types\.UnionType\s*#\s*value = (?P<value>.*)$",
25+
r"^(?P<variable>[a-zA-Z_][a-zA-Z0-9_]*): (types\.UnionType|typing\._UnionGenericAlias)\s*#\s*value = (?P<value>.*)$",
1426
flags=re.MULTILINE,
1527
)
1628

1729

1830
def union_sub_func(match: re.Match[str]) -> str:
19-
return f'{match.group("variable")}: typing.TypeAlias = {match.group("value")}'
31+
return f'{match.group("variable")}: typing.TypeAlias = {fix_value(match.group("value"))}'
2032

2133

2234
ClassVarUnionPattern = re.compile(
@@ -26,7 +38,7 @@ def union_sub_func(match: re.Match[str]) -> str:
2638

2739

2840
def class_var_union_sub_func(match: re.Match) -> str:
29-
return f'{match.group("variable")}: typing.TypeAlias = {match.group("value")}'
41+
return f'{match.group("variable")}: typing.TypeAlias = {fix_value(match.group("value"))}'
3042

3143

3244
VersionPattern = re.compile(r"(?P<var>[a-zA-Z0-9_].*): str = '.*?'")
@@ -87,7 +99,7 @@ def eq_sub_func(match: re.Match[str]) -> str:
8799

88100

89101
def generic_alias_sub_func(match: re.Match) -> str:
90-
return f"{match.group('variable')}: typing.TypeAlias = {match.group('value')}"
102+
return f'{match.group("variable")}: typing.TypeAlias = {fix_value(match.group("value"))}'
91103

92104

93105
def get_module_path(name: str) -> str:

0 commit comments

Comments
 (0)