Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions ns_extract/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __post_init__(self):
@dataclass
class ProcessedData:
coordinates: Optional[Path] = None
text: Path = None
text: Optional[Path] = None
metadata: Optional[Path] = None
raw: Optional[Union["PubgetRaw", "AceRaw"]] = field(default=None)

Expand All @@ -122,6 +122,7 @@ class Study:
pmcid: str = None
ace: ProcessedData = None
pubget: ProcessedData = None
db: ProcessedData = None
pipeline_results: Dict[str, PipelineRunResult] = field(default_factory=dict)

def __post_init__(self):
Expand Down Expand Up @@ -154,7 +155,7 @@ def __post_init__(self):
pubget_raw = PubgetRaw(xml=pubget_xml_path, tables_xml=tables_xml_path)

# Load processed data
for t in ["ace", "pubget"]:
for t in ["ace", "pubget", "db"]:
processed_dir = self.study_dir / "processed" / t
if processed_dir.exists():
try:
Expand All @@ -165,7 +166,7 @@ def __post_init__(self):
coordinates=(
coordinates_path if coordinates_path.exists() else None
),
text=text_path,
text=text_path if text_path.exists() else None,
metadata=metadata_path if metadata_path.exists() else None,
raw=ace_raw if t == "ace" else pubget_raw,
)
Expand Down
73 changes: 41 additions & 32 deletions ns_extract/pipelines/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ def _transform(self, inputs: dict, **kwargs) -> dict:
results = {}
for study_id, study_inputs in inputs.items():
# Get text content - already loaded by InputManager
text = study_inputs["text"]

text = study_inputs.get("text", "")
if not text:
logging.warning(f"No text found for study {study_id}")
results[study_id] = {}
continue
# Create chat completion configuration
completion_config = {
"messages": [
Expand Down Expand Up @@ -175,7 +178,9 @@ def get_nested_value(data: dict, key_path: str):
"""Access nested dictionary values using dot notation."""
keys = key_path.split(".")
for key in keys:
data = data[key]
data = data.get(key, {})
if not data:
return None
return data


Expand All @@ -199,6 +204,7 @@ def __init__(
**kwargs,
):
self.text_source = text_source
self.enc = tiktoken.encoding_for_model(extraction_model)
super().__init__(
extraction_model=extraction_model,
env_variable=env_variable,
Expand All @@ -208,6 +214,33 @@ def __init__(
**kwargs,
)

def chunk_paragraph(self, paragraph, max_tokens=MAX_TOKENS):
    """Split *paragraph* into chunks of at most *max_tokens* tokens.

    Token counts are measured with ``self.enc`` (the tiktoken encoding
    selected for the extraction model).  A paragraph that already fits
    within ``max_tokens`` is returned as a single chunk; paragraphs and
    resulting chunks shorter than ``MINIMUM_CHUNK_SIZE`` tokens are
    discarded as too small to embed usefully.

    Args:
        paragraph: Text to split.
        max_tokens: Maximum number of tokens allowed per chunk.

    Returns:
        A (possibly empty) list of chunk strings.
    """
    tokens = self.enc.encode(paragraph)
    # Fast path: the whole paragraph fits in one acceptable chunk.
    if MINIMUM_CHUNK_SIZE <= len(tokens) <= max_tokens:
        return [paragraph]
    # Too short to be a useful chunk at all.
    if len(tokens) < MINIMUM_CHUNK_SIZE:
        return []
    # Use spaCy to split into sentences (pipeline is lazily initialized).
    if not self._nlp_initialized:
        self._init_nlp_components()
    doc = self._nlp(paragraph)
    sentences = [sent.text for sent in doc.sents]
    chunks = []
    current_chunk = ""
    for sent in sentences:
        # Tentatively extend the running chunk with the next sentence.
        test_chunk = current_chunk + " " + sent if current_chunk else sent
        if len(self.enc.encode(test_chunk)) <= max_tokens:
            current_chunk = test_chunk
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            # NOTE(review): a single sentence longer than max_tokens is
            # carried forward unsplit and may still exceed the limit.
            current_chunk = sent
    if current_chunk:
        chunks.append(current_chunk.strip())
    # Drop chunks below the minimum size.  Measure in tokens — not
    # characters — for consistency with the checks above, where
    # MINIMUM_CHUNK_SIZE is compared against token counts.
    chunks = [
        chunk
        for chunk in chunks
        if len(self.enc.encode(chunk)) >= MINIMUM_CHUNK_SIZE
    ]
    return chunks

def _transform(self, inputs: dict, **kwargs) -> dict:
"""
Extract embeddings from text using the specified model.
Expand All @@ -219,41 +252,17 @@ def _transform(self, inputs: dict, **kwargs) -> dict:

results = {}
text_source = TEXT_MAPPING[self.text_source]
enc = tiktoken.encoding_for_model(self.extraction_model)

def chunk_paragraph(paragraph, max_tokens=MAX_TOKENS):
tokens = enc.encode(paragraph)
if len(tokens) <= max_tokens and len(tokens) >= MINIMUM_CHUNK_SIZE:
return [paragraph]
if len(tokens) < MINIMUM_CHUNK_SIZE:
return []
# Use spaCy to split into sentences
if not self._nlp_initialized:
self._init_nlp_components()
doc = self._nlp(paragraph)
sentences = [sent.text for sent in doc.sents]
chunks = []
current_chunk = ""
for sent in sentences:
test_chunk = current_chunk + " " + sent if current_chunk else sent
if len(enc.encode(test_chunk)) <= max_tokens:
current_chunk = test_chunk
else:
if current_chunk:
chunks.append(current_chunk.strip())
current_chunk = sent
if current_chunk:
chunks.append(current_chunk.strip())
# Filter out chunks that are too short
chunks = [chunk for chunk in chunks if len(chunk) >= MINIMUM_CHUNK_SIZE]
return chunks

for study_id, study_inputs in inputs.items():
text = get_nested_value(study_inputs, text_source)
if not text:
logging.warning(f"No text found for study {study_id}")
results[study_id] = {"embedding": []}
continue
paragraphs = text.split("\n\n")
all_chunks = []
for para in paragraphs:
all_chunks.extend(chunk_paragraph(para, MAX_TOKENS))
all_chunks.extend(self.chunk_paragraph(para, MAX_TOKENS))
embeddings = []
for chunk in all_chunks:
embedding_response = self.client.embeddings.create(
Expand Down
8 changes: 6 additions & 2 deletions ns_extract/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,14 @@ def __identify_matching_results(
# Get existing input file hashes for this study
existing = existing_results.get(dbid, {}).get("inputs", {})

# Skip if no existing results or no current inputs
if not existing or dbid not in study_inputs:
# Skip if no existing results and there are current inputs
if not existing and dbid in study_inputs and study_inputs[dbid]:
result_matches[dbid] = False
continue
elif not study_inputs[dbid]:
# no current inputs, so results are matching
result_matches[dbid] = True
continue

# Use __are_file_hashes_identical to compare hashes
result_matches[dbid] = self.__do_file_hashes_match(
Expand Down
2 changes: 1 addition & 1 deletion ns_extract/pipelines/semantic_embeddings/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ class GeneralAPIEmbeddingExtractor(APIEmbeddingExtractor, IndependentPipeline):

_version = "1.0.0"
_output_schema = EmbeddingSchema
_data_pond_inputs = {("pubget", "ace"): ("text", "metadata")}
_data_pond_inputs = {("pubget", "ace", "db"): ("text", "metadata")}
_pipeline_inputs = {}
12 changes: 9 additions & 3 deletions ns_extract/pipelines/tfidf/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from pydantic import BaseModel, Field
from typing import Dict, List, Literal, Optional
from ns_extract.pipelines.base import DependentPipeline, Extractor
Expand Down Expand Up @@ -89,10 +91,14 @@ def _transform(self, inputs: dict, **kwargs) -> dict:
# Process all texts
study_texts = {}
for study_id, study_inputs in inputs.items():
text = study_inputs["text"]
metadata = study_inputs["metadata"]
text = study_inputs.get("text", "")
metadata = study_inputs.get("metadata", {})
content = self.get_text_content(text, metadata)
study_texts[study_id] = content
if not content:
logging.warning(f"No content found for study {study_id}")
study_texts[study_id] = ""
else:
study_texts[study_id] = content

# Get list of all texts in same order as study IDs
study_ids = list(study_texts.keys())
Expand Down
Loading