Skip to content

Commit 780b457

Browse files
authored
refactor: split IQBPipeline and PipelineCacheManager (#44)
This diff moves the cache-management logic out of IQBPipeline and into a dedicated PipelineCacheManager type. In turn, this allows PipelineCacheManager to be reused by `./iqb/cache.py` without having to instantiate a BigQuery client. While there, use the `PIPELINE_` prefix for `pipeline.py` constants.
1 parent 7c050f3 commit 780b457

File tree

2 files changed

+270
-76
lines changed

2 files changed

+270
-76
lines changed

library/src/iqb/pipeline.py

Lines changed: 95 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@
9898
}
9999

100100
# Cache file names
101-
CACHE_DATA_FILENAME: Final[str] = "data.parquet"
102-
CACHE_STATS_FILENAME: Final[str] = "stats.json"
101+
PIPELINE_CACHE_DATA_FILENAME: Final[str] = "data.parquet"
102+
PIPELINE_CACHE_STATS_FILENAME: Final[str] = "stats.json"
103103

104104

105105
@dataclass(frozen=True)
@@ -115,12 +115,37 @@ class PipelineCacheEntry:
115115
Reference to a cache entry containing query results and metadata.
116116
117117
Attributes:
118-
data_path: Path to data.parquet file
119-
stats_path: Path to stats.json file
118+
data_dir: the Path that points to the data dir
119+
tname: the ParsedTemplateName to use
120+
start_time: the datetime containing the start time
121+
end_time: the datetime containing the end time
120122
"""
121123

122-
data_path: Path
123-
stats_path: Path
124+
data_dir: Path
125+
tname: ParsedTemplateName
126+
start_time: datetime
127+
end_time: datetime
128+
129+
def dir_path(self) -> Path:
130+
"""Returns the directory path where to write files."""
131+
fs_date_format = "%Y%m%dT000000Z"
132+
start_dir = self.start_time.strftime(fs_date_format)
133+
end_dir = self.end_time.strftime(fs_date_format)
134+
return self.data_dir / "cache" / "v1" / start_dir / end_dir / self.tname.value
135+
136+
def data_path(self) -> Path | None:
137+
"""Returns the path to the parquet data file, if it exists, or None."""
138+
value = self.dir_path() / PIPELINE_CACHE_DATA_FILENAME
139+
if not value.exists():
140+
return None
141+
return value
142+
143+
def stats_path(self) -> Path | None:
144+
"""Returns the path to the JSON stats file, if it exists, or None."""
145+
value = self.dir_path() / PIPELINE_CACHE_STATS_FILENAME
146+
if not value.exists():
147+
return None
148+
return value
124149

125150

126151
@dataclass(frozen=True)
@@ -164,7 +189,7 @@ def save_parquet(self) -> ParquetFileInfo:
164189
If the query returns no rows, an empty parquet file is written.
165190
"""
166191
self.cache_dir.mkdir(parents=True, exist_ok=True)
167-
parquet_path = self.cache_dir / CACHE_DATA_FILENAME
192+
parquet_path = self.cache_dir / PIPELINE_CACHE_DATA_FILENAME
168193

169194
# Note: using .as_posix to avoid paths with backslashes
170195
# that can cause issues with PyArrow on Windows
@@ -191,7 +216,7 @@ def save_stats(self) -> Path:
191216
Path to the written stats.json file.
192217
"""
193218
self.cache_dir.mkdir(parents=True, exist_ok=True)
194-
stats_path = self.cache_dir / CACHE_STATS_FILENAME
219+
stats_path = self.cache_dir / PIPELINE_CACHE_STATS_FILENAME
195220

196221
# Calculate query duration from BigQuery job
197222
query_duration_seconds = None
@@ -218,6 +243,51 @@ def save_stats(self) -> Path:
218243
return stats_path
219244

220245

246+
class PipelineCacheManager:
247+
"""Manages the cache populated by the IQBPipeline."""
248+
249+
def __init__(self, data_dir: str | Path | None = None):
250+
"""
251+
Initialize cache with data directory path.
252+
253+
Parameters:
254+
data_dir: Path to directory containing cached data files.
255+
If None, defaults to .iqb/ in current working directory.
256+
"""
257+
self.data_dir = data_dir_or_default(data_dir)
258+
259+
def get_cache_entry(
260+
self,
261+
template: str,
262+
start_date: str,
263+
end_date: str,
264+
) -> PipelineCacheEntry:
265+
"""
266+
Get cache entry for the given query.
267+
268+
Args:
269+
template: name for the query template (e.g., "downloads_by_country")
270+
start_date: Date when to start the query (included) -- format YYYY-MM-DD
271+
end_date: Date when to end the query (excluded) -- format YYYY-MM-DD
272+
273+
Returns:
274+
PipelineCacheEntry with correctly initialized fields.
275+
"""
276+
# 1. parse the start and the end dates
277+
start_time, end_time = _parse_both_dates(start_date, end_date)
278+
279+
# 2. ensure the template name is correct
280+
tname = _parse_template_name(template)
281+
282+
# 3. return the corresponding entry
283+
return PipelineCacheEntry(
284+
data_dir=self.data_dir,
285+
tname=tname,
286+
start_time=start_time,
287+
end_time=end_time,
288+
)
289+
290+
221291
class IQBPipeline:
222292
"""Component for populating the IQB-measurement-data cache."""
223293

@@ -232,18 +302,7 @@ def __init__(self, project_id: str, data_dir: str | Path | None = None):
232302
"""
233303
self.client = bigquery.Client(project=project_id)
234304
self.bq_read_clnt = bigquery_storage_v1.BigQueryReadClient()
235-
self.data_dir = data_dir_or_default(data_dir)
236-
237-
def _cache_dir_path(
238-
self,
239-
tname: ParsedTemplateName,
240-
start_time: datetime,
241-
end_time: datetime,
242-
) -> Path:
243-
fs_date_format = "%Y%m%dT000000Z"
244-
start_dir = start_time.strftime(fs_date_format)
245-
end_dir = end_time.strftime(fs_date_format)
246-
return self.data_dir / "cache" / "v1" / start_dir / end_dir / tname.value
305+
self.manager = PipelineCacheManager(data_dir)
247306

248307
def get_cache_entry(
249308
self,
@@ -269,36 +328,30 @@ def get_cache_entry(
269328
Raises:
270329
FileNotFoundError: if cache doesn't exist and fetch_if_missing is False.
271330
"""
272-
# 1. parse the start and the end dates
273-
start_time, end_time = _parse_both_dates(start_date, end_date)
331+
# 1. get the cache entry
332+
entry = self.manager.get_cache_entry(template, start_date, end_date)
274333

275-
# 2. ensure the template name is correct
276-
tname = _parse_template_name(template)
277-
278-
# 3. obtain information about the cache dir and files
279-
cache_dir = self._cache_dir_path(tname, start_time, end_time)
280-
data_path = cache_dir / CACHE_DATA_FILENAME
281-
stats_path = cache_dir / CACHE_STATS_FILENAME
282-
283-
# 4. check if cache exists
284-
if data_path.exists() and stats_path.exists():
285-
return PipelineCacheEntry(data_path=data_path, stats_path=stats_path)
334+
# 2. make sure the entry exists
335+
if entry.data_path() is not None and entry.stats_path() is not None:
336+
return entry
286337

287-
# 5. handle missing cache without auto-fetching
338+
# 3. handle missing cache without auto-fetching
288339
if not fetch_if_missing:
289340
raise FileNotFoundError(
290341
f"Cache entry not found for {template} "
291342
f"({start_date} to {end_date}). "
292343
f"Set fetch_if_missing=True to execute query."
293344
)
294345

295-
# 6. execute query and update the cache
296-
result = self._execute_query_template(tname, start_time, end_time)
346+
# 4. execute query and update the cache
347+
result = self._execute_query_template(entry)
297348
result.save_parquet()
298349
result.save_stats()
299350

300-
# 7. return information about the cache entry
301-
return PipelineCacheEntry(data_path=data_path, stats_path=stats_path)
351+
# 5. return information about the cache entry
352+
assert entry.data_path() is not None
353+
assert entry.stats_path() is not None
354+
return entry
302355

303356
def execute_query_template(
304357
self,
@@ -317,27 +370,13 @@ def execute_query_template(
317370
Returns:
318371
A QueryResult instance.
319372
"""
320-
# 1. parse the start and the end dates
321-
start_time, end_time = _parse_both_dates(start_date, end_date)
322-
323-
# 2. ensure the template name is correct
324-
tname = _parse_template_name(template)
325-
326-
# 3. defer to the private implementation
327373
return self._execute_query_template(
328-
tname=tname,
329-
start_time=start_time,
330-
end_time=end_time,
374+
self.manager.get_cache_entry(template, start_date, end_date)
331375
)
332376

333-
def _execute_query_template(
334-
self,
335-
tname: ParsedTemplateName,
336-
start_time: datetime,
337-
end_time: datetime,
338-
) -> QueryResult:
377+
def _execute_query_template(self, entry: PipelineCacheEntry) -> QueryResult:
339378
# 1. load the actual query
340-
query, template_hash = _load_query_template(tname, start_time, end_time)
379+
query, template_hash = _load_query_template(entry.tname, entry.start_time, entry.end_time)
341380

342381
# 2. record query start time (RFC3339 format with Z suffix)
343382
query_start_time = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
@@ -347,7 +386,7 @@ def _execute_query_template(
347386
rows = job.result()
348387

349388
# 4. compute the directory where we would save the results
350-
cache_dir = self._cache_dir_path(tname, start_time, end_time)
389+
cache_dir = entry.dir_path()
351390

352391
# 5. return the result object
353392
return QueryResult(

0 commit comments

Comments (0)