@@ -8,14 +8,27 @@ pub fn part1(s: &str) -> u64 {
88 unsafe { part1_inner ( s) }
99}
1010
11+ pub fn part2 ( _s : & str ) -> u64 {
12+ // To be sure you know...
13+ 42
14+ }
15+
1116const DS : usize = 7 * 6 + 1 ;
1217
18+ static LUT : [ u32 ; 4 ] = [ 0 , 0xFF , 0xFF_FF , 0xFF_FF_FF ] ;
19+
1320#[ inline( always) ]
1421unsafe fn part1_inner ( s : & [ u8 ] ) -> u64 {
1522 let mut sum = 0 ;
1623
17- let mut keys = heapless:: Vec :: < u64 , 512 > :: new ( ) ;
18- let mut holes = heapless:: Vec :: < u64 , 512 > :: new ( ) ;
24+ static mut KEYS : [ u64 ; 512 ] = unsafe { std:: mem:: transmute ( [ 0u8 ; 512 * 8 ] ) } ;
25+ static mut HOLES : [ u64 ; 512 ] = unsafe { std:: mem:: transmute ( [ 0u8 ; 512 * 8 ] ) } ;
26+
27+ let keys = & mut * ( & raw mut KEYS ) ;
28+ let holes = & mut * ( & raw mut HOLES ) ;
29+
30+ let mut keys_i = 0 ;
31+ let mut holes_i = 0 ;
1932
2033 let mut i = 0 ;
2134
@@ -50,38 +63,174 @@ unsafe fn part1_inner(s: &[u8]) -> u64 {
5063 . read_unaligned ( )
5164 & 0x0101010101 ) ;
5265
53- let other = if is_key { & holes } else { & keys } ;
54- let mut j = other. len ( ) ;
55- while j >= 4 {
56- j -= 4 ;
57- let o = other
58- . as_ptr ( )
59- . offset ( j as isize )
60- . cast :: < __m256i > ( )
61- . read_unaligned ( ) ;
62- let s = _mm256_add_epi64 ( o, _mm256_set1_epi64x ( d as i64 ) ) ;
63- let s = _mm256_and_si256 ( s, _mm256_set1_epi8 ( 0x80u8 as i8 ) ) ;
64- let s = _mm256_cmpeq_epi64 ( s, _mm256_set1_epi64x ( 0 ) ) ;
65- let s = _mm256_movemask_epi8 ( s) as u32 ;
66-
67- sum += s. count_ones ( ) as u64 / 8 ;
68- }
69- if j > 0 {
70- let o = other. as_ptr ( ) . cast :: < __m256i > ( ) . read_unaligned ( ) ;
71- let s = _mm256_add_epi64 ( o, _mm256_set1_epi64x ( d as i64 ) ) ;
72- let s = _mm256_and_si256 ( s, _mm256_set1_epi8 ( 0x80u8 as i8 ) ) ;
73- let s = _mm256_cmpeq_epi64 ( s, _mm256_set1_epi64x ( 0 ) ) ;
74- let s = _mm256_movemask_epi8 ( s) as u32 ;
75-
76- let s = s & [ 0 , 0xFF , 0xFF_FF , 0xFF_FF_FF ] . get_unchecked ( j) ;
77- sum += s. count_ones ( ) as u64 / 8 ;
78- }
79-
80- let d = d + 0x7A7A7A7A7A ;
8166 if is_key {
82- keys. push_unchecked ( d) ;
67+ std:: arch:: asm!(
68+ "test {max_i}, {max_i}" ,
69+ "je 2f" , // Jump on empty
70+ "mov {i}, {max_i}" ,
71+ "cmp {i}, 16" ,
72+ "jb 6f" , // Jump to < 16 case
73+
74+ "4:" ,
75+ "add {i}, -16" ,
76+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*3]" ,
77+ "vpand {vt}, {vt}, {msb}" ,
78+ "vpcmpeqq {vt}, {vt}, {zero}" ,
79+ "vpmovmskb {t}, {vt}" ,
80+ "popcnt {t}, {t}" ,
81+ "shr {t}, 3" ,
82+ "add {sum},{t}" ,
83+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*2]" ,
84+ "vpand {vt}, {vt}, {msb}" ,
85+ "vpcmpeqq {vt}, {vt}, {zero}" ,
86+ "vpmovmskb {t}, {vt}" ,
87+ "popcnt {t}, {t}" ,
88+ "shr {t}, 3" ,
89+ "add {sum},{t}" ,
90+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*1]" ,
91+ "vpand {vt}, {vt}, {msb}" ,
92+ "vpcmpeqq {vt}, {vt}, {zero}" ,
93+ "vpmovmskb {t}, {vt}" ,
94+ "popcnt {t}, {t}" ,
95+ "shr {t}, 3" ,
96+ "add {sum},{t}" ,
97+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*0]" ,
98+ "vpand {vt}, {vt}, {msb}" ,
99+ "vpcmpeqq {vt}, {vt}, {zero}" ,
100+ "vpmovmskb {t}, {vt}" ,
101+ "popcnt {t}, {t}" ,
102+ "shr {t}, 3" ,
103+ "add {sum},{t}" ,
104+ "cmp {i}, 16" ,
105+ "jae 4b" , // Loop
106+ "6:" ,
107+ "cmp {i}, 4" ,
108+ "jb 3f" , // Is < 4
109+ // Is >= 4 and < 16
110+
111+ "5:" ,
112+ "add {i}, -4" ,
113+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i}]" ,
114+ "vpand {vt}, {vt}, {msb}" ,
115+ "vpcmpeqq {vt}, {vt}, {zero}" ,
116+ "vpmovmskb {t}, {vt}" ,
117+ "popcnt {t}, {t}" ,
118+ "shr {t}, 3" ,
119+ "add {sum},{t}" ,
120+ "cmp {i}, 4" ,
121+ "jae 5b" , // Loop
122+ "3:" ,
123+ "test {i}, {i}" ,
124+ "je 2f" , // Is zero
125+
126+ // Is > 0 and < 4
127+ "vpaddq {vt}, {d}, ymmword ptr [{os}]" ,
128+ "vpand {vt}, {vt}, {msb}" ,
129+ "vpcmpeqq {vt}, {vt}, {zero}" ,
130+ "vpmovmskb {t}, {vt}" ,
131+ "and {t:e}, dword ptr [{lut} + 4*{i}]" ,
132+ "popcnt {t}, {t}" ,
133+ "shr {t}, 3" ,
134+ "add {sum},{t}" ,
135+ "2:" ,
136+ d = in( ymm_reg) _mm256_set1_epi64x( d as i64 ) ,
137+ msb = in( ymm_reg) _mm256_set1_epi8( 0x80u8 as i8 ) ,
138+ zero = in( ymm_reg) _mm256_set1_epi64x( 0 ) ,
139+ lut = in( reg) LUT . as_ptr( ) ,
140+ os = in( reg) holes,
141+ max_i = in( reg) holes_i,
142+ sum = inout( reg) sum,
143+ i = out( reg) _,
144+ t = out( reg) _,
145+ vt = out( ymm_reg) _,
146+ options( nostack) ,
147+ ) ;
148+ * keys. get_unchecked_mut ( keys_i) = d + 0x7A7A7A7A7A ;
149+ keys_i += 1 ;
83150 } else {
84- holes. push_unchecked ( d) ;
151+ std:: arch:: asm!(
152+ "test {max_i}, {max_i}" ,
153+ "je 2f" , // Jump on empty
154+ "mov {i}, {max_i}" ,
155+ "cmp {i}, 16" ,
156+ "jb 6f" , // Jump to < 16 case
157+
158+ "4:" ,
159+ "add {i}, -16" ,
160+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*3]" ,
161+ "vpand {vt}, {vt}, {msb}" ,
162+ "vpcmpeqq {vt}, {vt}, {zero}" ,
163+ "vpmovmskb {t}, {vt}" ,
164+ "popcnt {t}, {t}" ,
165+ "shr {t}, 3" ,
166+ "add {sum},{t}" ,
167+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*2]" ,
168+ "vpand {vt}, {vt}, {msb}" ,
169+ "vpcmpeqq {vt}, {vt}, {zero}" ,
170+ "vpmovmskb {t}, {vt}" ,
171+ "popcnt {t}, {t}" ,
172+ "shr {t}, 3" ,
173+ "add {sum},{t}" ,
174+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*1]" ,
175+ "vpand {vt}, {vt}, {msb}" ,
176+ "vpcmpeqq {vt}, {vt}, {zero}" ,
177+ "vpmovmskb {t}, {vt}" ,
178+ "popcnt {t}, {t}" ,
179+ "shr {t}, 3" ,
180+ "add {sum},{t}" ,
181+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i} + 32*0]" ,
182+ "vpand {vt}, {vt}, {msb}" ,
183+ "vpcmpeqq {vt}, {vt}, {zero}" ,
184+ "vpmovmskb {t}, {vt}" ,
185+ "popcnt {t}, {t}" ,
186+ "shr {t}, 3" ,
187+ "add {sum},{t}" ,
188+ "cmp {i}, 16" ,
189+ "jae 4b" , // Loop
190+ "6:" ,
191+ "cmp {i}, 4" ,
192+ "jb 3f" , // Is < 4
193+ // Is >= 4 and < 16
194+
195+ "5:" ,
196+ "add {i}, -4" ,
197+ "vpaddq {vt}, {d}, ymmword ptr [{os} + 8*{i}]" ,
198+ "vpand {vt}, {vt}, {msb}" ,
199+ "vpcmpeqq {vt}, {vt}, {zero}" ,
200+ "vpmovmskb {t}, {vt}" ,
201+ "popcnt {t}, {t}" ,
202+ "shr {t}, 3" ,
203+ "add {sum},{t}" ,
204+ "cmp {i}, 4" ,
205+ "jae 5b" , // Loop
206+ "3:" ,
207+ "test {i}, {i}" ,
208+ "je 2f" , // Is zero
209+
210+ // Is > 0 and < 4
211+ "vpaddq {vt}, {d}, ymmword ptr [{os}]" ,
212+ "vpand {vt}, {vt}, {msb}" ,
213+ "vpcmpeqq {vt}, {vt}, {zero}" ,
214+ "vpmovmskb {t}, {vt}" ,
215+ "and {t:e}, dword ptr [{lut} + 4*{i}]" ,
216+ "popcnt {t}, {t}" ,
217+ "shr {t}, 3" ,
218+ "add {sum},{t}" ,
219+ "2:" ,
220+ d = in( ymm_reg) _mm256_set1_epi64x( d as i64 ) ,
221+ msb = in( ymm_reg) _mm256_set1_epi8( 0x80u8 as i8 ) ,
222+ zero = in( ymm_reg) _mm256_set1_epi64x( 0 ) ,
223+ lut = in( reg) LUT . as_ptr( ) ,
224+ os = in( reg) keys,
225+ max_i = in( reg) keys_i,
226+ sum = inout( reg) sum,
227+ i = out( reg) _,
228+ t = out( reg) _,
229+ vt = out( ymm_reg) _,
230+ options( nostack) ,
231+ ) ;
232+ * holes. get_unchecked_mut ( holes_i) = d + 0x7A7A7A7A7A ;
233+ holes_i += 1 ;
85234 }
86235
87236 i += DS ;
0 commit comments