Skip to content

Commit 2c64769

Browse files
mangelatsLuthaf
authored andcommitted
Add access methods to the generated SoA vectors
1 parent e1716d7 commit 2c64769

File tree

3 files changed

+402
-0
lines changed

3 files changed

+402
-0
lines changed

soa-derive-internal/src/index.rs

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
use proc_macro2::TokenStream;
2+
use quote::quote;
3+
4+
use crate::input::Input;
5+
6+
pub fn derive(input: &Input) -> TokenStream {
7+
let vec_name = &input.vec_name();
8+
let slice_name = &input.slice_name();
9+
let slice_mut_name = &input.slice_mut_name();
10+
let ref_name = &input.ref_name();
11+
let ref_mut_name = &input.ref_mut_name();
12+
let fields_names = input.fields.iter()
13+
.map(|field| field.ident.clone().unwrap())
14+
.collect::<Vec<_>>();
15+
let fields_names_1 = &fields_names;
16+
let fields_names_2 = &fields_names;
17+
18+
quote!{
19+
// usize
20+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for usize {
21+
type RefOutput = #ref_name<'a>;
22+
23+
#[inline]
24+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
25+
if self < soa.len() {
26+
Some(unsafe { self.get_unchecked(soa) })
27+
} else {
28+
None
29+
}
30+
}
31+
32+
#[inline]
33+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
34+
#ref_name {
35+
#(#fields_names_1: soa.#fields_names_2.get_unchecked(self),)*
36+
}
37+
}
38+
39+
#[inline]
40+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
41+
#ref_name {
42+
#(#fields_names_1: & soa.#fields_names_2[self],)*
43+
}
44+
}
45+
}
46+
47+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for usize {
48+
type MutOutput = #ref_mut_name<'a>;
49+
50+
#[inline]
51+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
52+
if self < soa.len() {
53+
Some(unsafe { self.get_unchecked_mut(soa) })
54+
} else {
55+
None
56+
}
57+
}
58+
59+
#[inline]
60+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
61+
#ref_mut_name {
62+
#(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self),)*
63+
}
64+
}
65+
66+
#[inline]
67+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
68+
#ref_mut_name {
69+
#(#fields_names_1: &mut soa.#fields_names_2[self],)*
70+
}
71+
}
72+
}
73+
74+
75+
76+
// Range<usize>
77+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::Range<usize> {
78+
type RefOutput = #slice_name<'a>;
79+
80+
#[inline]
81+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
82+
if self.start <= self.end && self.end <= soa.len() {
83+
unsafe { Some(self.get_unchecked(soa)) }
84+
} else {
85+
None
86+
}
87+
}
88+
89+
#[inline]
90+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
91+
#slice_name {
92+
#(#fields_names_1: soa.#fields_names_2.get_unchecked(self.clone()),)*
93+
}
94+
}
95+
96+
#[inline]
97+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
98+
#slice_name {
99+
#(#fields_names_1: & soa.#fields_names_2[self.clone()],)*
100+
}
101+
}
102+
}
103+
104+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::Range<usize> {
105+
type MutOutput = #slice_mut_name<'a>;
106+
107+
#[inline]
108+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
109+
if self.start <= self.end && self.end <= soa.len() {
110+
unsafe { Some(self.get_unchecked_mut(soa)) }
111+
} else {
112+
None
113+
}
114+
}
115+
116+
#[inline]
117+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
118+
#slice_mut_name {
119+
#(#fields_names_1: soa.#fields_names_2.get_unchecked_mut(self.clone()),)*
120+
}
121+
}
122+
123+
#[inline]
124+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
125+
#slice_mut_name {
126+
#(#fields_names_1: &mut soa.#fields_names_2[self.clone()],)*
127+
}
128+
}
129+
}
130+
131+
132+
133+
// RangeTo<usize>
134+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeTo<usize> {
135+
type RefOutput = #slice_name<'a>;
136+
137+
#[inline]
138+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
139+
(0..self.end).get(soa)
140+
}
141+
142+
#[inline]
143+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
144+
(0..self.end).get_unchecked(soa)
145+
}
146+
147+
#[inline]
148+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
149+
(0..self.end).index(soa)
150+
}
151+
}
152+
153+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeTo<usize> {
154+
type MutOutput = #slice_mut_name<'a>;
155+
156+
#[inline]
157+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
158+
(0..self.end).get_mut(soa)
159+
}
160+
161+
#[inline]
162+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
163+
(0..self.end).get_unchecked_mut(soa)
164+
}
165+
166+
#[inline]
167+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
168+
(0..self.end).index_mut(soa)
169+
}
170+
}
171+
172+
173+
// RangeFrom<usize>
174+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFrom<usize> {
175+
type RefOutput = #slice_name<'a>;
176+
177+
#[inline]
178+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
179+
(self.start..soa.len()).get(soa)
180+
}
181+
182+
#[inline]
183+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
184+
(self.start..soa.len()).get_unchecked(soa)
185+
}
186+
187+
#[inline]
188+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
189+
(self.start..soa.len()).index(soa)
190+
}
191+
}
192+
193+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFrom<usize> {
194+
type MutOutput = #slice_mut_name<'a>;
195+
196+
#[inline]
197+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
198+
(self.start..soa.len()).get_mut(soa)
199+
}
200+
201+
#[inline]
202+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
203+
(self.start..soa.len()).get_unchecked_mut(soa)
204+
}
205+
206+
#[inline]
207+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
208+
(self.start..soa.len()).index_mut(soa)
209+
}
210+
}
211+
212+
213+
// RangeFull
214+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeFull {
215+
type RefOutput = #slice_name<'a>;
216+
217+
#[inline]
218+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
219+
Some(soa.as_slice())
220+
}
221+
222+
#[inline]
223+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
224+
soa.as_slice()
225+
}
226+
227+
#[inline]
228+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
229+
soa.as_slice()
230+
}
231+
}
232+
233+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeFull {
234+
type MutOutput = #slice_mut_name<'a>;
235+
236+
#[inline]
237+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
238+
Some(soa.as_mut_slice())
239+
}
240+
241+
#[inline]
242+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
243+
soa.as_mut_slice()
244+
}
245+
246+
#[inline]
247+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
248+
soa.as_mut_slice()
249+
}
250+
}
251+
252+
253+
// RangeInclusive<usize>
254+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeInclusive<usize> {
255+
type RefOutput = #slice_name<'a>;
256+
257+
#[inline]
258+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
259+
if *self.end() == usize::MAX {
260+
None
261+
} else {
262+
(*self.start()..self.end() + 1).get(soa)
263+
}
264+
}
265+
266+
#[inline]
267+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
268+
(*self.start()..self.end() + 1).get_unchecked(soa)
269+
}
270+
271+
#[inline]
272+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
273+
(*self.start()..self.end() + 1).index(soa)
274+
}
275+
}
276+
277+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeInclusive<usize> {
278+
type MutOutput = #slice_mut_name<'a>;
279+
280+
#[inline]
281+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
282+
if *self.end() == usize::MAX {
283+
None
284+
} else {
285+
(*self.start()..self.end() + 1).get_mut(soa)
286+
}
287+
}
288+
289+
#[inline]
290+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
291+
(*self.start()..self.end() + 1).get_unchecked_mut(soa)
292+
}
293+
294+
#[inline]
295+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
296+
(*self.start()..self.end() + 1).index_mut(soa)
297+
}
298+
}
299+
300+
301+
// RangeToInclusive<usize>
302+
impl<'a> ::soa_derive::SoaIndex<&'a #vec_name> for ::std::ops::RangeToInclusive<usize> {
303+
type RefOutput = #slice_name<'a>;
304+
305+
#[inline]
306+
fn get(self, soa: &'a #vec_name) -> Option<Self::RefOutput> {
307+
(0..=self.end).get(soa)
308+
}
309+
310+
#[inline]
311+
unsafe fn get_unchecked(self, soa: &'a #vec_name) -> Self::RefOutput {
312+
(0..=self.end).get_unchecked(soa)
313+
}
314+
315+
#[inline]
316+
fn index(self, soa: &'a #vec_name) -> Self::RefOutput {
317+
(0..=self.end).index(soa)
318+
}
319+
}
320+
321+
impl<'a> ::soa_derive::SoaMutIndex<&'a mut #vec_name> for ::std::ops::RangeToInclusive<usize> {
322+
type MutOutput = #slice_mut_name<'a>;
323+
324+
#[inline]
325+
fn get_mut(self, soa: &'a mut #vec_name) -> Option<Self::MutOutput> {
326+
(0..=self.end).get_mut(soa)
327+
}
328+
329+
#[inline]
330+
unsafe fn get_unchecked_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
331+
(0..=self.end).get_unchecked_mut(soa)
332+
}
333+
334+
#[inline]
335+
fn index_mut(self, soa: &'a mut #vec_name) -> Self::MutOutput {
336+
(0..=self.end).index_mut(soa)
337+
}
338+
}
339+
}
340+
}

soa-derive-internal/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ extern crate proc_macro;
99
use proc_macro2::TokenStream;
1010
use quote::TokenStreamExt;
1111

12+
mod index;
1213
mod input;
1314
mod iter;
1415
mod ptr;
@@ -27,6 +28,7 @@ pub fn soa_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
2728
generated.append_all(ptr::derive(&input));
2829
generated.append_all(slice::derive(&input));
2930
generated.append_all(slice::derive_mut(&input));
31+
generated.append_all(index::derive(&input));
3032
generated.append_all(iter::derive(&input));
3133
generated.append_all(derive_trait(&input));
3234
generated.into()

0 commit comments

Comments
 (0)