Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions biothings/hub/datatransform/datatransform_mdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def __init__(self, graph, *args, **kwargs):
source document regardless as to weather it matches an
edge or not. (advanced usage)
:type copy_from_doc: bool

Note: Prefixes can be defined at the node level using:
graph.add_node("chebi", prefix="CHEBI")
When an identifier is converted to a node with a prefix attribute,
the prefix will be automatically added to the _id.
"""
if not isinstance(graph, nx.DiGraph):
raise ValueError("key_lookup configuration error: graph must be of type nx.DiGraph")
Expand All @@ -198,6 +203,29 @@ def __init__(self, graph, *args, **kwargs):
super(DataTransformMDB, self).__init__(*args, **kwargs)
self._precompute_paths()

def _apply_prefix(self, identifier, output_type):
"""
Apply prefix to identifier based on output type.

Prefixes are defined as node attributes in the graph:
graph.add_node("chebi", prefix="CHEBI")

:param identifier: The identifier value to potentially prefix
:param output_type: The output type to check for prefix
:return: The identifier with prefix applied if configured
"""
# Check if the node has a prefix attribute
if output_type in self.graph.nodes():
node_data = self.graph.nodes[output_type]
if 'prefix' in node_data:
prefix = node_data['prefix']
identifier_str = str(identifier)
# Only add prefix if it's not already there
if not identifier_str.startswith(prefix + ":"):
return f"{prefix}:{identifier_str}"

return str(identifier)

def _valid_input_type(self, input_type):
return input_type.lower() in self.graph.nodes()

Expand Down Expand Up @@ -292,7 +320,7 @@ def key_lookup_batch(self, batchiter):
(hit_lst, miss_lst) = self.travel(input_type, output_type, miss_lst)
# or if copy is allowed, we get the value from the doc
elif self.copy_from_doc:
(hit_lst, miss_lst) = self._copy(input_type, miss_lst)
(hit_lst, miss_lst) = self._copy(input_type, output_type, miss_lst)
else:
(hit_lst, miss_lst) = self.travel(input_type, output_type, miss_lst)

Expand All @@ -305,15 +333,15 @@ def key_lookup_batch(self, batchiter):
for doc in miss_lst:
yield doc

def _copy(self, input_type, doc_lst):
def _copy(self, input_type, output_type, doc_lst):
"""Copy ids in the case where input_type == output_type"""
hit_lst = []
miss_lst = []
for doc in doc_lst:
val = nested_lookup(doc, input_type[1])
if val:
# ensure _id is always a str
doc["_id"] = str(val)
# ensure _id is always a str and apply prefix if configured
doc["_id"] = self._apply_prefix(val, output_type)
hit_lst.append(doc)
# retain debug information if available (assumed dt_debug already in place)
if self.debug:
Expand Down Expand Up @@ -371,8 +399,8 @@ def _build_hit_miss_lsts(doc_lst, id_strct, debug):
value = nested_lookup(doc, input_type[1])
for lookup_id in id_strct.find_left(value):
new_doc = copy.deepcopy(doc)
# ensure _id is always a str
new_doc["_id"] = str(lookup_id)
# ensure _id is always a str and apply prefix if configured
new_doc["_id"] = self._apply_prefix(lookup_id, target)
# capture debug information
if debug:
new_doc["dt_debug"]["start_field"] = input_type[1]
Expand Down