@@ -11,6 +11,8 @@ import (
1111 "regexp"
1212 "strconv"
1313 "strings"
14+ "sync"
15+ "sync/atomic"
1416 "time"
1517
1618 gocql "github.com/apache/cassandra-gocql-driver/v2"
@@ -474,11 +476,11 @@ func (h *MetaCommandHandler) handleCopyFrom(command string) interface{} {
474476 // Prepare for building INSERT statements
475477 columnList := strings .Join (columns , ", " )
476478
477- // Process rows
478- rowCount := 0 // Successfully imported rows
479- processedRows := 0 // Total rows processed (for MAXROWS check)
479+ // Process rows - use atomic counters for thread safety with concurrent workers
480+ var rowCount int64
481+ var insertErrorCount int64
482+ processedRows := 0 // Only accessed from main goroutine
480483 parseErrorCount := 0
481- insertErrorCount := 0
482484 skippedRows := 0
483485
484486 // Parse numeric options
@@ -488,9 +490,14 @@ func (h *MetaCommandHandler) handleCopyFrom(command string) interface{} {
488490 maxParseErrors , _ := strconv .Atoi (options ["MAXPARSEERRORS" ])
489491 maxInsertErrors , _ := strconv .Atoi (options ["MAXINSERTERRORS" ])
490492 maxBatchSize , _ := strconv .Atoi (options ["MAXBATCHSIZE" ])
491- minBatchSize , _ := strconv .Atoi (options ["MINBATCHSIZE " ])
493+ maxRequests , _ := strconv .Atoi (options ["MAXREQUESTS " ])
492494 nullVal := options ["NULLVAL" ]
493495
496+ // Ensure reasonable defaults
497+ if maxRequests < 1 {
498+ maxRequests = 6
499+ }
500+
494501 // Skip initial rows if specified
495502 for i := 0 ; i < skipRows ; i ++ {
496503 _ , err := csvReader .Read ()
@@ -500,8 +507,34 @@ func (h *MetaCommandHandler) handleCopyFrom(command string) interface{} {
500507 skippedRows ++
501508 }
502509
510+ // Create prepared statement template with placeholders
511+ placeholders := make ([]string , len (columns ))
512+ for i := range placeholders {
513+ placeholders [i ] = "?"
514+ }
515+ insertTemplate := fmt .Sprintf ("INSERT INTO %s (%s) VALUES (%s)" ,
516+ table , columnList , strings .Join (placeholders , ", " ))
517+
518+ // Create batch channel and wait group for concurrent execution
519+ batchChan := make (chan []batchEntry , maxRequests * 2 )
520+ var wg sync.WaitGroup
521+
522+ // Start worker goroutines for concurrent batch execution
523+ for i := 0 ; i < maxRequests ; i ++ {
524+ wg .Add (1 )
525+ go func () {
526+ defer wg .Done ()
527+ for batch := range batchChan {
528+ errors := h .executeBatchWithValues (batch )
529+ atomic .AddInt64 (& insertErrorCount , int64 (errors ))
530+ atomic .AddInt64 (& rowCount , int64 (len (batch )- errors ))
531+ }
532+ }()
533+ }
534+
503535 // Prepare batch for inserts
504- batch := make ([]string , 0 , maxBatchSize )
536+ batch := make ([]batchEntry , 0 , maxBatchSize )
537+ lastProgress := int64 (0 )
505538
506539 for {
507540 record , err := csvReader .Read ()
@@ -511,7 +544,9 @@ func (h *MetaCommandHandler) handleCopyFrom(command string) interface{} {
511544 if err != nil {
512545 parseErrorCount ++
513546 if maxParseErrors != - 1 && parseErrorCount > maxParseErrors {
514- return fmt .Sprintf ("Too many parse errors. Imported %d rows, failed after %d parse errors" , rowCount , parseErrorCount )
547+ close (batchChan )
548+ wg .Wait ()
549+ return fmt .Sprintf ("Too many parse errors. Imported %d rows, failed after %d parse errors" , atomic .LoadInt64 (& rowCount ), parseErrorCount )
515550 }
516551 continue
517552 }
@@ -526,84 +561,83 @@ func (h *MetaCommandHandler) handleCopyFrom(command string) interface{} {
526561 if len (record ) != len (columns ) {
527562 parseErrorCount ++
528563 if maxParseErrors != - 1 && parseErrorCount > maxParseErrors {
529- return fmt .Sprintf ("Too many parse errors. Imported %d rows, failed after %d parse errors" , rowCount , parseErrorCount )
564+ close (batchChan )
565+ wg .Wait ()
566+ return fmt .Sprintf ("Too many parse errors. Imported %d rows, failed after %d parse errors" , atomic .LoadInt64 (& rowCount ), parseErrorCount )
530567 }
531568 continue
532569 }
533570
534- // Convert values and build INSERT query
535- valueStrings := make ([]string , len (record ))
571+ // Convert values for prepared statement binding
572+ values := make ([]interface {} , len (record ))
536573 for i , val := range record {
537574 // Handle NULL values
538575 if val == nullVal {
539- valueStrings [i ] = "NULL"
576+ values [i ] = nil
540577 } else {
541- valueStrings [i ] = h .formatValueForInsert (val , columns [i ], table )
578+ values [i ] = h .parseValueForBinding (val , columns [i ], table )
542579 }
543580 }
544581
545- // Build INSERT query
546- insertQuery := fmt .Sprintf ("INSERT INTO %s (%s) VALUES (%s)" ,
547- table , columnList , strings .Join (valueStrings , ", " ))
582+ // Add to batch with prepared statement template
583+ batch = append (batch , batchEntry {query : insertTemplate , values : values })
548584
549- // Add to batch
550- batch = append (batch , insertQuery )
551-
552- // Execute batch if it reaches maxBatchSize
585+ // Send batch to workers if it reaches maxBatchSize
553586 if len (batch ) >= maxBatchSize {
554- errors := h .executeBatch (batch )
555- insertErrorCount += errors
556- if maxInsertErrors != - 1 && insertErrorCount > maxInsertErrors {
557- return fmt .Sprintf ("Too many insert errors. Imported %d rows, failed after %d insert errors" , rowCount , insertErrorCount )
587+ // Check for too many insert errors
588+ if maxInsertErrors != - 1 && atomic .LoadInt64 (& insertErrorCount ) > int64 (maxInsertErrors ) {
589+ close (batchChan )
590+ wg .Wait ()
591+ return fmt .Sprintf ("Too many insert errors. Imported %d rows, failed after %d insert errors" , atomic .LoadInt64 (& rowCount ), atomic .LoadInt64 (& insertErrorCount ))
558592 }
559- rowCount += len (batch ) - errors
593+ // Send batch to workers (make a copy since we reuse the slice)
594+ batchCopy := make ([]batchEntry , len (batch ))
595+ copy (batchCopy , batch )
596+ batchChan <- batchCopy
560597 batch = batch [:0 ] // Clear batch
561598 }
562599
563600 // Progress update for large imports
564- if rowCount % chunkSize == 0 && ! isStdin {
565- fmt .Printf ("\r Imported %d rows..." , rowCount )
601+ currentRows := atomic .LoadInt64 (& rowCount )
602+ if currentRows - lastProgress >= int64 (chunkSize ) && ! isStdin {
603+ fmt .Printf ("\r Imported %d rows..." , currentRows )
604+ lastProgress = currentRows
566605 }
567606 }
568607
569- // Execute any remaining batch
608+ // Send any remaining batch
570609 if len (batch ) > 0 {
571- if len (batch ) >= minBatchSize || rowCount == 0 {
572- errors := h .executeBatch (batch )
573- insertErrorCount += errors
574- rowCount += len (batch ) - errors
575- } else {
576- // Execute individually if below minBatchSize
577- for _ , query := range batch {
578- result := h .session .ExecuteCQLQuery (query )
579- if _ , ok := result .(error ); ok {
580- insertErrorCount ++
581- } else {
582- rowCount ++
583- }
584- }
585- }
610+ batchCopy := make ([]batchEntry , len (batch ))
611+ copy (batchCopy , batch )
612+ batchChan <- batchCopy
586613 }
587614
588- if ! isStdin && rowCount > chunkSize {
615+ // Close channel and wait for all workers to finish
616+ close (batchChan )
617+ wg .Wait ()
618+
619+ finalRowCount := atomic .LoadInt64 (& rowCount )
620+ finalInsertErrors := atomic .LoadInt64 (& insertErrorCount )
621+
622+ if ! isStdin && finalRowCount > int64 (chunkSize ) {
589623 fmt .Println () // New line after progress updates
590624 }
591625
592- totalErrors := parseErrorCount + insertErrorCount
626+ totalErrors := int64 ( parseErrorCount ) + finalInsertErrors
593627 if totalErrors > 0 {
594- details := fmt .Sprintf ("Imported %d rows from %s" , rowCount , filename )
628+ details := fmt .Sprintf ("Imported %d rows from %s" , finalRowCount , filename )
595629 if skipRows > 0 {
596630 details += fmt .Sprintf (" (skipped %d rows)" , skippedRows )
597631 }
598632 if parseErrorCount > 0 {
599633 details += fmt .Sprintf (" (%d parse errors)" , parseErrorCount )
600634 }
601- if insertErrorCount > 0 {
602- details += fmt .Sprintf (" (%d insert errors)" , insertErrorCount )
635+ if finalInsertErrors > 0 {
636+ details += fmt .Sprintf (" (%d insert errors)" , finalInsertErrors )
603637 }
604638 return details
605639 }
606- details := fmt .Sprintf ("Imported %d rows from %s" , rowCount , filename )
640+ details := fmt .Sprintf ("Imported %d rows from %s" , finalRowCount , filename )
607641 if skipRows > 0 {
608642 details += fmt .Sprintf (" (skipped %d rows)" , skippedRows )
609643 }
0 commit comments