Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 29 additions & 15 deletions src/atomworks/io/common.py → src/atomworks/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations
"""Common functions used throughout the project."""

import copy
import hashlib
from collections import OrderedDict
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable
from functools import lru_cache, wraps
from typing import Any

Expand All @@ -12,26 +11,29 @@


def exists(obj: Any) -> bool:
    """Return `True` when `obj` is anything other than `None`.

    Falsy values such as `0`, `""` or `[]` still count as existing; only the
    `None` singleton does not.
    """
    if obj is None:
        return False
    return True


def default(obj: Any, default: Any) -> Any:
    """Fall back to `default` only when `obj` is `None`.

    Falsy but non-`None` values (`0`, `""`, `[]`, ...) are returned unchanged.
    """
    if exists(obj):
        return obj
    return default


def deduplicate_iterator(iterator: Iterable) -> Iterator:
    """Return an iterator over the distinct items of `iterator`, in order of first appearance.

    The input is consumed eagerly and its items must be hashable.
    """
    # Dict keys are unique and keep insertion order, which gives order-preserving dedup.
    first_seen = dict.fromkeys(iterator)
    return iter(first_seen)


def to_hashable(element: Any) -> Any:
    """Convert an element to a hashable type.

    Integers and strings (Python or NumPy scalars) are returned unchanged.
    Anything else is passed to `tuple`, so it must be iterable; nested items
    are not converted recursively.
    """
    passthrough_types = (int, str, np.integer, np.str_)
    if isinstance(element, passthrough_types):
        return element
    return tuple(element)


def string_to_md5_hash(s: str, truncate: int = 32) -> str:
    """Generate an MD5 hash of a string and return the first `truncate` characters.

    Args:
        s: Text to hash; it is encoded as UTF-8 before hashing.
        truncate: Number of leading hex characters to keep. An MD5 hex digest
            is 32 characters long, so the default returns the whole digest.

    Returns:
        The (possibly truncated) hexadecimal digest.
    """
    # The digest is a content fingerprint, not a security primitive. Declaring
    # that keeps the call working on FIPS-restricted builds where plain MD5 is
    # blocked; the resulting digest is unchanged.
    digest = hashlib.md5(s.encode("utf-8"), usedforsecurity=False).hexdigest()
    return digest[:truncate]


def sum_string_arrays(*objs: np.ndarray | str) -> np.ndarray:
"""
Sum a list of string arrays / strings into a single string array by concatenating them and
Sum a list of string arrays or strings into a single string array by concatenating them and
determining the shortest string length to set as dtype.
"""
return reduce(np.char.add, objs).astype(object).astype(str)
Expand All @@ -47,6 +49,24 @@ def listmap(func: Callable, *iterables) -> list:
return compose(list, map)(func, *iterables)


def as_list(value: Any) -> list:
    """Convert a value to a list.

    Handles various types using duck typing:
        - Strings: treated as a single value and wrapped in a list, even
          though they are iterable.
        - Other iterable objects (lists, tuples, sets, generators, etc.):
          converted to a list of their items.
        - Non-iterable values: wrapped in a list.

    Args:
        value: The value to convert.

    Returns:
        A new list built according to the rules above.
    """
    # Strings are iterable, but callers want them kept whole.
    if isinstance(value, str):
        return [value]

    # Only the iterability probe is guarded, so a TypeError raised while the
    # items are being produced is not mistaken for "not iterable".
    try:
        items = iter(value)
    except TypeError:
        return [value]
    return list(items)


def immutable_lru_cache(maxsize: int = 128, typed: bool = False, deepcopy: bool = True) -> Callable:
"""An immutable version of `lru_cache` for caching functions that return mutable objects."""
copy_func = copy.deepcopy if deepcopy else copy.copy
Expand Down Expand Up @@ -89,9 +109,3 @@ def __call__(self, value: Any) -> int:
self.key_to_id[value] = self.next_id
self.next_id += 1
return self.key_to_id[value]


def md5_hash_string(s: str, length: int = 32) -> str:
    """Generate an MD5 hash of a string and return the first `length` characters.

    Args:
        s: Text to hash; it is encoded as UTF-8 before hashing.
        length: Number of leading hex characters to keep. An MD5 hex digest
            is 32 characters long, so the default returns the whole digest.

    Returns:
        The (possibly truncated) hexadecimal digest.
    """
    # The digest is a content fingerprint, not a security primitive. Declaring
    # that keeps the call working on FIPS-restricted builds where plain MD5 is
    # blocked; the resulting digest is unchanged.
    digest = hashlib.md5(s.encode("utf-8"), usedforsecurity=False).hexdigest()
    return digest[:length]
File renamed without changes.
2 changes: 1 addition & 1 deletion src/atomworks/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from toolz import keymap

from atomworks.io.constants import (
from atomworks.constants import (
AA_LIKE_CHEM_TYPES,
DNA_LIKE_CHEM_TYPES,
POLYPEPTIDE_D_CHEM_TYPES,
Expand Down
6 changes: 3 additions & 3 deletions src/atomworks/io/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from toolz import keyfilter

import atomworks.io.transforms.atom_array as ta
from atomworks.common import exists, string_to_md5_hash
from atomworks.constants import CCD_MIRROR_PATH, CRYSTALLIZATION_AIDS, WATER_LIKE_CCDS
from atomworks.io import template
from atomworks.io.common import exists, md5_hash_string
from atomworks.io.constants import CCD_MIRROR_PATH, CRYSTALLIZATION_AIDS, WATER_LIKE_CCDS
from atomworks.io.transforms.categories import (
category_to_dict,
extract_crystallization_details,
Expand Down Expand Up @@ -218,7 +218,7 @@ def parse(
}
# Compose args_string from parse_arguments values (in order)
args_string = ",".join(str(parse_arguments[k]) for k in parse_arguments)
args_hash = md5_hash_string(args_string, length=8)
args_hash = string_to_md5_hash(args_string, truncate=8)

# ... generate assembly info
assembly_info = ",".join(build_assembly) if isinstance(build_assembly, list | tuple) else build_assembly
Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from biotite.structure import AtomArray, BondList

import atomworks.io.transforms.atom_array as ta
from atomworks.io.common import exists, immutable_lru_cache
from atomworks.io.constants import CCD_MIRROR_PATH, DO_NOT_MATCH_CCD
from atomworks.common import exists, immutable_lru_cache
from atomworks.constants import CCD_MIRROR_PATH, DO_NOT_MATCH_CCD
from atomworks.io.utils.bonds import (
correct_bond_types_for_nucleophilic_additions,
correct_formal_charges_for_specified_atoms,
Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/io/tools/fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import os
import re

from atomworks.constants import CCD_MIRROR_PATH
from atomworks.enums import ChainType
from atomworks.io.constants import CCD_MIRROR_PATH
from atomworks.io.utils.ccd import (
check_ccd_codes_are_available,
)
Expand Down
8 changes: 4 additions & 4 deletions src/atomworks/io/tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
from rdkit.Chem import AllChem

import atomworks.io.transforms.atom_array as ta
from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io import parse
from atomworks.io.common import KeyToIntMapper, exists
from atomworks.io.constants import (
from atomworks.common import KeyToIntMapper, exists
from atomworks.constants import (
CCD_MIRROR_PATH,
STANDARD_AA_ONE_LETTER,
STANDARD_DNA_ONE_LETTER,
STANDARD_RNA,
UNKNOWN_LIGAND,
)
from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io import parse
from atomworks.io.parser import DEFAULT_PARSE_KWARGS
from atomworks.io.template import build_template_atom_array
from atomworks.io.tools.fasta import one_letter_to_ccd_code, split_generalized_fasta_sequence
Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/tools/rdkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from rdkit.DataStructs import ExplicitBitVect

import atomworks.io.transforms.atom_array as ta
from atomworks.io.common import exists, immutable_lru_cache, not_isin
from atomworks.io.constants import (
from atomworks.common import exists, immutable_lru_cache, not_isin
from atomworks.constants import (
BIOTITE_DEFAULT_ANNOTATIONS,
CCD_MIRROR_PATH,
HYDROGEN_LIKE_SYMBOLS,
Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/transforms/atom_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import pandas as pd
from biotite.structure import AtomArray, AtomArrayStack, stack

from atomworks.io.common import listmap, not_isin, sum_string_arrays
from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, HYDROGEN_LIKE_SYMBOLS, WATER_LIKE_CCDS
from atomworks.common import listmap, not_isin, sum_string_arrays
from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, HYDROGEN_LIKE_SYMBOLS, WATER_LIKE_CCDS
from atomworks.io.utils.bonds import (
generate_inter_level_bond_hash,
get_coarse_graph_as_nodes_and_edges,
Expand Down
8 changes: 5 additions & 3 deletions src/atomworks/io/transforms/categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
from biotite.structure import AtomArray
from biotite.structure.io.pdbx import CIFBlock

from atomworks.common import exists
from atomworks.constants import CCD_MIRROR_PATH
from atomworks.enums import ChainType
from atomworks.io.common import deduplicate_iterator, exists
from atomworks.io.constants import CCD_MIRROR_PATH
from atomworks.io.utils.selection import get_residue_starts
from atomworks.io.utils.sequence import get_1_from_3_letter_code

Expand Down Expand Up @@ -253,7 +253,9 @@ def load_monomer_sequence_information_from_category(

# Build up the chain_info_dict with the sequence information
res_starts = get_residue_starts(atom_array)
for chain_id in deduplicate_iterator(struc.get_chains(atom_array)):
# ... get the unique chain IDs by order of first appearance in the AtomArray
chain_ids = dict.fromkeys(struc.get_chains(atom_array))
for chain_id in chain_ids:
rcsb_entity = int(chain_info_dict[chain_id]["rcsb_entity"])

if rcsb_entity in polymer_entity_id_to_res_names_and_ids:
Expand Down
6 changes: 3 additions & 3 deletions src/atomworks/io/utils/bonds.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@
_get_struct_conn_col_name,
)

from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io.common import sum_string_arrays, to_hashable
from atomworks.io.constants import (
from atomworks.common import sum_string_arrays, to_hashable
from atomworks.constants import (
AA_LIKE_CHEM_TYPES,
CHEM_TYPE_POLYMERIZATION_ATOMS,
DEFAULT_VALENCE,
Expand All @@ -39,6 +38,7 @@
STRUCT_CONN_BOND_ORDER_TO_INT,
STRUCT_CONN_BOND_TYPES,
)
from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io.utils.ccd import get_chem_comp_leaving_atom_names, get_chem_comp_type
from atomworks.io.utils.selection import get_annotation, get_residue_starts
from atomworks.io.utils.testing import has_ambiguous_annotation_set
Expand Down
6 changes: 3 additions & 3 deletions src/atomworks/io/utils/ccd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import numpy as np
import toolz

from atomworks.enums import ChainType, ChainTypeInfo
from atomworks.io.common import exists, immutable_lru_cache
from atomworks.io.constants import (
from atomworks.common import exists, immutable_lru_cache
from atomworks.constants import (
AA_LIKE_CHEM_TYPES,
CCD_MIRROR_PATH,
DNA_LIKE_CHEM_TYPES,
Expand All @@ -25,6 +24,7 @@
UNKNOWN_LIGAND,
UNKNOWN_RNA,
)
from atomworks.enums import ChainType, ChainTypeInfo

logger = logging.getLogger(__name__)

Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from biotite.structure.io import mol, pdbx

import atomworks.io.transforms.atom_array as ta # to avoid circular import
from atomworks.common import exists
from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT, STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.enums import ChainType
from atomworks.io.common import exists
from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, STANDARD_AA, STANDARD_DNA, STANDARD_RNA
from atomworks.io.template import add_inter_residue_bonds
from atomworks.io.transforms.categories import category_to_dict
from atomworks.io.utils.selection import get_annotation
Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/utils/non_rcsb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from biotite.structure import AtomArray
from biotite.structure.io.pdbx import CIFCategory

from atomworks.enums import ChainType
from atomworks.io.constants import (
from atomworks.constants import (
AA_LIKE_CHEM_TYPES,
DNA_LIKE_CHEM_TYPES,
POLYPEPTIDE_D_CHEM_TYPES,
POLYPEPTIDE_L_CHEM_TYPES,
RNA_LIKE_CHEM_TYPES,
)
from atomworks.enums import ChainType
from atomworks.io.utils.ccd import get_chem_comp_type
from atomworks.io.utils.selection import get_residue_starts
from atomworks.io.utils.sequence import get_1_from_3_letter_code
Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/io/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from biotite.structure import AtomArray, AtomArrayStack

from atomworks.io.common import not_isin
from atomworks.common import not_isin
from atomworks.io.transforms.atom_array import is_any_coord_nan


Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/io/utils/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
import numpy as np
import toolz

from atomworks.enums import ChainType
from atomworks.io.constants import (
from atomworks.constants import (
GAP,
GAP_ONE_LETTER,
STANDARD_AA,
Expand All @@ -25,6 +24,7 @@
UNKNOWN_DNA,
UNKNOWN_RNA,
)
from atomworks.enums import ChainType
from atomworks.io.utils.ccd import (
aa_chem_comps,
chem_comp_to_one_letter,
Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/io/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from biotite.structure.atoms import AtomArray, AtomArrayStack

import atomworks.io.utils.bonds as cb
from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.constants import PDB_MIRROR_PATH
from atomworks.io.utils.scatter import apply_group_wise, apply_segment_wise


Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/io/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.io import mol, pdb, pdbx

from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, METAL_ELEMENTS
from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT, METAL_ELEMENTS
from atomworks.io.utils.io_utils import read_any, to_cif_string

logger = logging.getLogger("atomworks.io")
Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/ml/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
from torch.utils.data import ConcatDataset, Dataset

from atomworks.ml.common import default, exists
from atomworks.common import default, exists
from atomworks.ml.datasets import logger
from atomworks.ml.datasets.parsers import MetadataRowParser, load_example_from_metadata_row
from atomworks.ml.preprocessing.constants import NA_VALUES
Expand Down
2 changes: 1 addition & 1 deletion src/atomworks/ml/datasets/parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import pandas as pd

from atomworks.constants import CRYSTALLIZATION_AIDS
from atomworks.io import parse
from atomworks.io.constants import CRYSTALLIZATION_AIDS

DEFAULT_CIF_PARSER_ARGS = {
"add_missing_atoms": True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas as pd

from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.constants import PDB_MIRROR_PATH
from atomworks.ml.datasets.parsers import MetadataRowParser


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import pandas as pd

from atomworks.io.constants import PDB_MIRROR_PATH
from atomworks.ml.common import as_list
from atomworks.common import as_list
from atomworks.constants import PDB_MIRROR_PATH
from atomworks.ml.datasets.parsers import MetadataRowParser


Expand Down
4 changes: 2 additions & 2 deletions src/atomworks/ml/encoding_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import biotite.structure as struc
import numpy as np

from atomworks.io.constants import (
from atomworks.common import exists
from atomworks.constants import (
AA_LIKE_CHEM_TYPES,
CHEM_COMP_TYPES,
DNA_LIKE_CHEM_TYPES,
Expand All @@ -25,7 +26,6 @@
UNKNOWN_RNA,
)
from atomworks.io.utils.ccd import get_chem_comp_type
from atomworks.ml.common import exists

logger = getLogger(__name__)

Expand Down
Loading
Loading