Skip to content

Commit 82e64ef

Browse files
committed
Some inline asm
1 parent 6c576c3 commit 82e64ef

File tree

1 file changed

+181
-32
lines changed

1 file changed

+181
-32
lines changed

src/day25.rs

Lines changed: 181 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,27 @@ pub fn part1(s: &str) -> u64 {
88
unsafe { part1_inner(s) }
99
}
1010

11+
/// Placeholder second part: ignores its input and always returns the constant `42`.
pub fn part2(_s: &str) -> u64 {
12+
// To be sure you know...
13+
42
14+
}
15+
1116
/// Stride added to the scan index `i` after each key/hole branch (`i += DS`).
/// NOTE(review): presumably 7 rows of 6 bytes plus one separator byte per schematic — confirm
/// against the input format.
const DS: usize = 7 * 6 + 1;
1217

18+
/// Masks ANDed with the `vpmovmskb` result on the path that handles 1..=3 leftover entries:
/// entry `n` keeps the low `8 * n` bits, i.e. the byte-mask bits of the first `n` 64-bit lanes.
static LUT: [u32; 4] = [0, 0xFF, 0xFF_FF, 0xFF_FF_FF];
19+
1320
#[inline(always)]
1421
unsafe fn part1_inner(s: &[u8]) -> u64 {
1522
let mut sum = 0;
1623

17-
let mut keys = heapless::Vec::<u64, 512>::new();
18-
let mut holes = heapless::Vec::<u64, 512>::new();
24+
static mut KEYS: [u64; 512] = unsafe { std::mem::transmute([0u8; 512 * 8]) };
25+
static mut HOLES: [u64; 512] = unsafe { std::mem::transmute([0u8; 512 * 8]) };
26+
27+
let keys = &mut *(&raw mut KEYS);
28+
let holes = &mut *(&raw mut HOLES);
29+
30+
let mut keys_i = 0;
31+
let mut holes_i = 0;
1932

2033
let mut i = 0;
2134

@@ -50,38 +63,174 @@ unsafe fn part1_inner(s: &[u8]) -> u64 {
5063
.read_unaligned()
5164
& 0x0101010101);
5265

53-
let other = if is_key { &holes } else { &keys };
54-
let mut j = other.len();
55-
while j >= 4 {
56-
j -= 4;
57-
let o = other
58-
.as_ptr()
59-
.offset(j as isize)
60-
.cast::<__m256i>()
61-
.read_unaligned();
62-
let s = _mm256_add_epi64(o, _mm256_set1_epi64x(d as i64));
63-
let s = _mm256_and_si256(s, _mm256_set1_epi8(0x80u8 as i8));
64-
let s = _mm256_cmpeq_epi64(s, _mm256_set1_epi64x(0));
65-
let s = _mm256_movemask_epi8(s) as u32;
66-
67-
sum += s.count_ones() as u64 / 8;
68-
}
69-
if j > 0 {
70-
let o = other.as_ptr().cast::<__m256i>().read_unaligned();
71-
let s = _mm256_add_epi64(o, _mm256_set1_epi64x(d as i64));
72-
let s = _mm256_and_si256(s, _mm256_set1_epi8(0x80u8 as i8));
73-
let s = _mm256_cmpeq_epi64(s, _mm256_set1_epi64x(0));
74-
let s = _mm256_movemask_epi8(s) as u32;
75-
76-
let s = s & [0, 0xFF, 0xFF_FF, 0xFF_FF_FF].get_unchecked(j);
77-
sum += s.count_ones() as u64 / 8;
78-
}
79-
80-
let d = d + 0x7A7A7A7A7A;
8166
if is_key {
82-
keys.push_unchecked(d);
67+
std::arch::asm!(
68+
"test {max_i}, {max_i}",
69+
"je 2f", // Jump on empty
70+
"mov {i}, {max_i}",
71+
"cmp {i}, 16",
72+
"jb 6f", // Jump to < 16 case
73+
74+
"4:",
75+
"add {i}, -16",
76+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*3]",
77+
"vpand {vt}, {vt}, {msb}",
78+
"vpcmpeqq {vt}, {vt}, {zero}",
79+
"vpmovmskb {t}, {vt}",
80+
"popcnt {t}, {t}",
81+
"shr {t}, 3",
82+
"add {sum},{t}",
83+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*2]",
84+
"vpand {vt}, {vt}, {msb}",
85+
"vpcmpeqq {vt}, {vt}, {zero}",
86+
"vpmovmskb {t}, {vt}",
87+
"popcnt {t}, {t}",
88+
"shr {t}, 3",
89+
"add {sum},{t}",
90+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*1]",
91+
"vpand {vt}, {vt}, {msb}",
92+
"vpcmpeqq {vt}, {vt}, {zero}",
93+
"vpmovmskb {t}, {vt}",
94+
"popcnt {t}, {t}",
95+
"shr {t}, 3",
96+
"add {sum},{t}",
97+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*0]",
98+
"vpand {vt}, {vt}, {msb}",
99+
"vpcmpeqq {vt}, {vt}, {zero}",
100+
"vpmovmskb {t}, {vt}",
101+
"popcnt {t}, {t}",
102+
"shr {t}, 3",
103+
"add {sum},{t}",
104+
"cmp {i}, 16",
105+
"jae 4b", // Loop
106+
"6:",
107+
"cmp {i}, 4",
108+
"jb 3f", // Is < 4
109+
// Is >= 4 and < 16
110+
111+
"5:",
112+
"add {i}, -4",
113+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i}]",
114+
"vpand {vt}, {vt}, {msb}",
115+
"vpcmpeqq {vt}, {vt}, {zero}",
116+
"vpmovmskb {t}, {vt}",
117+
"popcnt {t}, {t}",
118+
"shr {t}, 3",
119+
"add {sum},{t}",
120+
"cmp {i}, 4",
121+
"jae 5b", // Loop
122+
"3:",
123+
"test {i}, {i}",
124+
"je 2f", // Is zero
125+
126+
// Is > 0 and < 4
127+
"vpaddq {vt}, {d}, ymmword ptr [{os}]",
128+
"vpand {vt}, {vt}, {msb}",
129+
"vpcmpeqq {vt}, {vt}, {zero}",
130+
"vpmovmskb {t}, {vt}",
131+
"and {t:e}, dword ptr [{lut} + 4*{i}]",
132+
"popcnt {t}, {t}",
133+
"shr {t}, 3",
134+
"add {sum},{t}",
135+
"2:",
136+
d = in(ymm_reg) _mm256_set1_epi64x(d as i64),
137+
msb = in(ymm_reg) _mm256_set1_epi8(0x80u8 as i8),
138+
zero = in(ymm_reg) _mm256_set1_epi64x(0),
139+
lut = in(reg) LUT.as_ptr(),
140+
os = in(reg) holes,
141+
max_i = in(reg) holes_i,
142+
sum = inout(reg) sum,
143+
i = out(reg) _,
144+
t = out(reg) _,
145+
vt = out(ymm_reg) _,
146+
options(nostack),
147+
);
148+
*keys.get_unchecked_mut(keys_i) = d + 0x7A7A7A7A7A;
149+
keys_i += 1;
83150
} else {
84-
holes.push_unchecked(d);
151+
std::arch::asm!(
152+
"test {max_i}, {max_i}",
153+
"je 2f", // Jump on empty
154+
"mov {i}, {max_i}",
155+
"cmp {i}, 16",
156+
"jb 6f", // Jump to < 16 case
157+
158+
"4:",
159+
"add {i}, -16",
160+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*3]",
161+
"vpand {vt}, {vt}, {msb}",
162+
"vpcmpeqq {vt}, {vt}, {zero}",
163+
"vpmovmskb {t}, {vt}",
164+
"popcnt {t}, {t}",
165+
"shr {t}, 3",
166+
"add {sum},{t}",
167+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*2]",
168+
"vpand {vt}, {vt}, {msb}",
169+
"vpcmpeqq {vt}, {vt}, {zero}",
170+
"vpmovmskb {t}, {vt}",
171+
"popcnt {t}, {t}",
172+
"shr {t}, 3",
173+
"add {sum},{t}",
174+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*1]",
175+
"vpand {vt}, {vt}, {msb}",
176+
"vpcmpeqq {vt}, {vt}, {zero}",
177+
"vpmovmskb {t}, {vt}",
178+
"popcnt {t}, {t}",
179+
"shr {t}, 3",
180+
"add {sum},{t}",
181+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*0]",
182+
"vpand {vt}, {vt}, {msb}",
183+
"vpcmpeqq {vt}, {vt}, {zero}",
184+
"vpmovmskb {t}, {vt}",
185+
"popcnt {t}, {t}",
186+
"shr {t}, 3",
187+
"add {sum},{t}",
188+
"cmp {i}, 16",
189+
"jae 4b", // Loop
190+
"6:",
191+
"cmp {i}, 4",
192+
"jb 3f", // Is < 4
193+
// Is >= 4 and < 16
194+
195+
"5:",
196+
"add {i}, -4",
197+
"vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i}]",
198+
"vpand {vt}, {vt}, {msb}",
199+
"vpcmpeqq {vt}, {vt}, {zero}",
200+
"vpmovmskb {t}, {vt}",
201+
"popcnt {t}, {t}",
202+
"shr {t}, 3",
203+
"add {sum},{t}",
204+
"cmp {i}, 4",
205+
"jae 5b", // Loop
206+
"3:",
207+
"test {i}, {i}",
208+
"je 2f", // Is zero
209+
210+
// Is > 0 and < 4
211+
"vpaddq {vt}, {d}, ymmword ptr [{os}]",
212+
"vpand {vt}, {vt}, {msb}",
213+
"vpcmpeqq {vt}, {vt}, {zero}",
214+
"vpmovmskb {t}, {vt}",
215+
"and {t:e}, dword ptr [{lut} + 4*{i}]",
216+
"popcnt {t}, {t}",
217+
"shr {t}, 3",
218+
"add {sum},{t}",
219+
"2:",
220+
d = in(ymm_reg) _mm256_set1_epi64x(d as i64),
221+
msb = in(ymm_reg) _mm256_set1_epi8(0x80u8 as i8),
222+
zero = in(ymm_reg) _mm256_set1_epi64x(0),
223+
lut = in(reg) LUT.as_ptr(),
224+
os = in(reg) keys,
225+
max_i = in(reg) keys_i,
226+
sum = inout(reg) sum,
227+
i = out(reg) _,
228+
t = out(reg) _,
229+
vt = out(ymm_reg) _,
230+
options(nostack),
231+
);
232+
*holes.get_unchecked_mut(holes_i) = d + 0x7A7A7A7A7A;
233+
holes_i += 1;
85234
}
86235

87236
i += DS;

0 commit comments

Comments
 (0)