44"""A module containing load method definition."""
55
66import logging
7- import re
87from io import BytesIO
98
109import pandas as pd
1110
1211from graphrag .config .models .input_config import InputConfig
13- from graphrag .index .utils . hashing import gen_sha512_hash
12+ from graphrag .index .input . util import load_files , process_data_columns
1413from graphrag .logger .base import ProgressLogger
1514from graphrag .storage .pipeline_storage import PipelineStorage
1615
1716log = logging .getLogger (__name__ )
1817
19- DEFAULT_FILE_PATTERN = re .compile (r"(?P<filename>[^\\/]).csv$" )
2018
21- input_type = "csv"
22-
23-
24- async def load (
19+ async def load_csv (
2520 config : InputConfig ,
2621 progress : ProgressLogger | None ,
2722 storage : PipelineStorage ,
@@ -39,61 +34,12 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
3934 data [[* additional_keys ]] = data .apply (
4035 lambda _row : pd .Series ([group [key ] for key in additional_keys ]), axis = 1
4136 )
42- if "id" not in data .columns :
43- data ["id" ] = data .apply (lambda x : gen_sha512_hash (x , x .keys ()), axis = 1 )
44- if config .text_column is not None and "text" not in data .columns :
45- if config .text_column not in data .columns :
46- log .warning (
47- "text_column %s not found in csv file %s" ,
48- config .text_column ,
49- path ,
50- )
51- else :
52- data ["text" ] = data .apply (lambda x : x [config .text_column ], axis = 1 )
53- if config .title_column is not None :
54- if config .title_column not in data .columns :
55- log .warning (
56- "title_column %s not found in csv file %s" ,
57- config .title_column ,
58- path ,
59- )
60- else :
61- data ["title" ] = data .apply (lambda x : x [config .title_column ], axis = 1 )
62- else :
63- data ["title" ] = data .apply (lambda _ : path , axis = 1 )
37+
38+ data = process_data_columns (data , config , path )
6439
6540 creation_date = await storage .get_creation_date (path )
6641 data ["creation_date" ] = data .apply (lambda _ : creation_date , axis = 1 )
6742
6843 return data
6944
70- file_pattern = (
71- re .compile (config .file_pattern )
72- if config .file_pattern is not None
73- else DEFAULT_FILE_PATTERN
74- )
75- files = list (
76- storage .find (
77- file_pattern ,
78- progress = progress ,
79- file_filter = config .file_filter ,
80- )
81- )
82-
83- if len (files ) == 0 :
84- msg = f"No CSV files found in { config .base_dir } "
85- raise ValueError (msg )
86-
87- files_loaded = []
88-
89- for file , group in files :
90- try :
91- files_loaded .append (await load_file (file , group ))
92- except Exception : # noqa: BLE001 (catching Exception is fine here)
93- log .warning ("Warning! Error loading csv file %s. Skipping..." , file )
94-
95- log .info ("Found %d csv files, loading %d" , len (files ), len (files_loaded ))
96- result = pd .concat (files_loaded )
97- total_files_log = f"Total number of unfiltered csv rows: { len (result )} "
98- log .info (total_files_log )
99- return result
45+ return await load_files (load_file , config , storage , progress )
0 commit comments