Skip to content

Commit 10094de

Browse files
authored
refactor: optimize quote detection and writing (#75)
* refactor: optimize quote detection and writing in Writer using SIMD and string operations * fix: fix lint error
1 parent 463ce36 commit 10094de

File tree

1 file changed

+78
-23
lines changed

1 file changed

+78
-23
lines changed

writer.go

Lines changed: 78 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"bufio"
77
"io"
88
"math/bits"
9+
"strings"
910
"unsafe"
1011

1112
"simd/archsimd"
@@ -95,6 +96,10 @@ func (w *Writer) Error() error {
9596
return w.err
9697
}
9798

99+
// writerSIMDMinSize is the minimum field size for SIMD benefit in Writer.
100+
// Smaller than the general simdMinThreshold because we use padded operations.
101+
const writerSIMDMinSize = 8
102+
98103
// fieldNeedsQuotes reports whether field requires quoting.
99104
// Dispatches to SIMD or scalar based on CPU support and field size.
100105
func (w *Writer) fieldNeedsQuotes(field string) bool {
@@ -105,15 +110,23 @@ func (w *Writer) fieldNeedsQuotes(field string) bool {
105110
if field[0] == ' ' || field[0] == '\t' {
106111
return true
107112
}
108-
// Use SIMD for ASCII delimiters and fields meeting size threshold
109-
if shouldUseSIMD(len(field)) && w.Comma >= 0 && w.Comma < 128 {
113+
// Use SIMD for ASCII delimiters (most common case)
114+
if useAVX512 && len(field) >= writerSIMDMinSize && w.Comma >= 0 && w.Comma < 128 {
110115
return w.fieldNeedsQuotesSIMD(field)
111116
}
112117
return w.fieldNeedsQuotesScalar(field)
113118
}
114119

115-
// fieldNeedsQuotesScalar checks for special characters using scalar iteration.
120+
// fieldNeedsQuotesScalar checks for special characters using optimized string search.
121+
// strings.IndexAny is internally optimized and uses SIMD on modern Go runtimes.
116122
func (w *Writer) fieldNeedsQuotesScalar(field string) bool {
123+
// For ASCII comma (common case), use IndexAny with precomputed charset
124+
if w.Comma < 128 {
125+
// Build search charset: comma + newline + carriage return + quote
126+
charset := string([]byte{byte(w.Comma), '\n', '\r', '"'})
127+
return strings.ContainsAny(field, charset)
128+
}
129+
// For non-ASCII comma, fall back to rune iteration
117130
for _, c := range field {
118131
if c == w.Comma || c == '\n' || c == '\r' || c == '"' {
119132
return true
@@ -123,6 +136,7 @@ func (w *Writer) fieldNeedsQuotesScalar(field string) bool {
123136
}
124137

125138
// fieldNeedsQuotesSIMD uses AVX-512 SIMD to detect special characters requiring quoting.
139+
// Handles any field size >= writerSIMDMinSize using padded operations for partial chunks.
126140
func (w *Writer) fieldNeedsQuotesSIMD(field string) bool {
127141
data := unsafe.Slice(unsafe.StringData(field), len(field))
128142
int8Data := bytesToInt8Slice(data)
@@ -145,10 +159,20 @@ func (w *Writer) fieldNeedsQuotesSIMD(field string) bool {
145159
offset += simdChunkSize
146160
}
147161

148-
// Process remaining bytes (< 64 bytes, scalar is sufficient)
149-
for ; offset < len(data); offset++ {
150-
c := data[offset]
151-
if c == byte(w.Comma) || c == '\n' || c == '\r' || c == '"' {
162+
// Process remaining bytes using SIMD with partial load
163+
if offset < len(data) {
164+
remaining := data[offset:]
165+
chunk := archsimd.LoadInt8x64SlicePart(bytesToInt8Slice(remaining))
166+
167+
commaMask := chunk.Equal(commaCmp).ToBits()
168+
newlineMask := chunk.Equal(cachedNlCmp).ToBits()
169+
crMask := chunk.Equal(cachedCrCmp).ToBits()
170+
quoteMask := chunk.Equal(cachedQuoteCmp).ToBits()
171+
172+
// Mask out bits beyond valid data
173+
validBits := len(remaining)
174+
mask := (uint64(1) << validBits) - 1
175+
if (commaMask|newlineMask|crMask|quoteMask)&mask != 0 {
152176
return true
153177
}
154178
}
@@ -160,29 +184,45 @@ func (w *Writer) writeQuotedField(field string) error {
160184
if err := w.w.WriteByte('"'); err != nil {
161185
return err
162186
}
163-
if shouldUseSIMD(len(field)) {
187+
if useAVX512 && len(field) >= writerSIMDMinSize {
164188
return w.writeQuotedFieldSIMD(field)
165189
}
166190
return w.writeQuotedFieldScalar(field)
167191
}
168192

169-
// writeQuotedFieldScalar escapes quotes using scalar iteration.
193+
// writeQuotedFieldScalar escapes quotes using optimized batch writing.
194+
// Instead of writing character by character, it finds quotes using IndexByte
195+
// and writes chunks between quotes in single WriteString calls.
170196
func (w *Writer) writeQuotedFieldScalar(field string) error {
171-
for _, c := range field {
172-
if c == '"' {
173-
if _, err := w.w.WriteString(`""`); err != nil {
174-
return err
175-
}
176-
} else {
177-
if _, err := w.w.WriteRune(c); err != nil {
178-
return err
179-
}
197+
lastWritten := 0
198+
for i := 0; i < len(field); {
199+
// Find next quote position from current offset
200+
idx := strings.IndexByte(field[i:], '"')
201+
if idx == -1 {
202+
break // No more quotes in remaining string
203+
}
204+
quotePos := i + idx
205+
// Write content up to and including the quote, then add escape quote
206+
if _, err := w.w.WriteString(field[lastWritten : quotePos+1]); err != nil {
207+
return err
208+
}
209+
if err := w.w.WriteByte('"'); err != nil {
210+
return err
211+
}
212+
lastWritten = quotePos + 1
213+
i = lastWritten
214+
}
215+
// Write remaining content after last quote (or entire field if no quotes)
216+
if lastWritten < len(field) {
217+
if _, err := w.w.WriteString(field[lastWritten:]); err != nil {
218+
return err
180219
}
181220
}
182221
return w.w.WriteByte('"')
183222
}
184223

185224
// writeQuotedFieldSIMD escapes quotes using AVX-512 SIMD to find quote positions.
225+
// Handles any field size >= writerSIMDMinSize using padded operations for partial chunks.
186226
func (w *Writer) writeQuotedFieldSIMD(field string) error {
187227
data := unsafe.Slice(unsafe.StringData(field), len(field))
188228
int8Data := bytesToInt8Slice(data)
@@ -213,16 +253,31 @@ func (w *Writer) writeQuotedFieldSIMD(field string) error {
213253
offset += simdChunkSize
214254
}
215255

216-
// Process remaining bytes (< 64 bytes, scalar is sufficient)
217-
for ; offset < len(data); offset++ {
218-
if data[offset] == '"' {
219-
if _, err := w.w.WriteString(field[lastWritten : offset+1]); err != nil {
256+
// Process remaining bytes using SIMD with partial load
257+
if offset < len(data) {
258+
remaining := data[offset:]
259+
chunk := archsimd.LoadInt8x64SlicePart(bytesToInt8Slice(remaining))
260+
mask := chunk.Equal(cachedQuoteCmp).ToBits()
261+
262+
// Mask out bits beyond valid data
263+
validBits := len(remaining)
264+
validMask := (uint64(1) << validBits) - 1
265+
mask &= validMask
266+
267+
for mask != 0 {
268+
pos := bits.TrailingZeros64(mask)
269+
quotePos := offset + pos
270+
271+
// Write content up to and including the quote, then add escape quote
272+
if _, err := w.w.WriteString(field[lastWritten : quotePos+1]); err != nil {
220273
return err
221274
}
222275
if err := w.w.WriteByte('"'); err != nil {
223276
return err
224277
}
225-
lastWritten = offset + 1
278+
279+
lastWritten = quotePos + 1
280+
mask &= ^(uint64(1) << pos)
226281
}
227282
}
228283

0 commit comments

Comments
 (0)