Skip to content

Commit c1ca373

Browse files
authored
Refactor DNA translation functions to use PyBytes
1 parent 6ff9c1e commit c1ca373

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

src/python_api.rs

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#![allow(clippy::borrow_deref_ref)] // TODO: broken clippy lint?
2-
// Copyright 2021-2024 SecureDNA Stiftung (SecureDNA Foundation) <licensing@securedna.org>
3-
// SPDX-License-Identifier: MIT OR Apache-2.0
2+
// Copyright 2021-2024 SecureDNA Stiftung (SecureDNA Foundation) <licensing@securedna.org>
3+
// SPDX-License-Identifier: MIT OR Apache-2.0
44

5-
use pyo3::{exceptions::PyValueError, prelude::*, types::{PyBytes, PyAny}};
5+
use pyo3::{exceptions::PyValueError, prelude::*, types::PyBytes};
66

77
use crate::{
88
errors::TranslationError,
@@ -29,11 +29,10 @@ fn _check_table(table: u8) -> PyResult<()> {
2929
///
3030
/// * `translate(b"CCNTACACK CATNCNAAT")` returns `b"PYTHXN"`
3131
#[pyfunction]
32-
fn _translate(py: Python, table: u8, dna: &PyAny) -> PyResult<Py<PyAny>> {
32+
fn _translate(py: Python, table: u8, dna: &PyBytes) -> PyResult<Py<PyAny>> {
3333
let table = TranslationTable::try_from(table)?;
34-
let dna_bytes = dna.downcast::<PyBytes>()?;
35-
let bytes = table.translate_dna_bytes::<NucleotideAmbiguous>(dna_bytes.as_bytes())?;
36-
Ok(PyBytes::new(py, &bytes).into_py(py))
34+
let bytes = table.translate_dna_bytes::<NucleotideAmbiguous>(dna.as_bytes())?;
35+
Ok(PyBytes::new(py, &bytes).into())
3736
}
3837

3938
/// Translate a bytestring of DNA nucleotides into a bytestring of amino acids.
@@ -43,11 +42,10 @@ fn _translate(py: Python, table: u8, dna: &PyAny) -> PyResult<Py<PyAny>> {
4342
/// * `translate_strict(b"AAACCCTTTGGG")` returns `b"KPFG"`
4443
/// * `translate_strict(b"AAACCCTTTGGN")` is an error.
4544
#[pyfunction]
46-
fn _translate_strict(py: Python, table: u8, dna: &PyAny) -> PyResult<Py<PyAny>> {
45+
fn _translate_strict(py: Python, table: u8, dna: &PyBytes) -> PyResult<Py<PyAny>> {
4746
let table = TranslationTable::try_from(table)?;
48-
let dna_bytes = dna.downcast::<PyBytes>()?;
49-
let bytes = table.translate_dna_bytes::<Nucleotide>(dna_bytes.as_bytes())?;
50-
Ok(PyBytes::new(py, &bytes).into_py(py))
47+
let bytes = table.translate_dna_bytes::<Nucleotide>(dna.as_bytes())?;
48+
Ok(PyBytes::new(py, &bytes).into())
5149
}
5250

5351
/// Get the reverse complement of a bytestring of DNA nucleotides.
@@ -56,10 +54,9 @@ fn _translate_strict(py: Python, table: u8, dna: &PyAny) -> PyResult<Py<PyAny>>
5654
///
5755
/// * `reverse_complement(b"AAAAABCCC")` returns `b"GGGVTTTTT"`
5856
#[pyfunction]
59-
fn _reverse_complement(py: Python, dna: &PyAny) -> PyResult<Py<PyAny>> {
60-
let dna_bytes = dna.downcast::<PyBytes>()?;
61-
let bytes = reverse_complement_bytes::<NucleotideAmbiguous>(dna_bytes.as_bytes())?;
62-
Ok(PyBytes::new(py, &bytes).into_py(py))
57+
fn _reverse_complement(py: Python, dna: &PyBytes) -> PyResult<Py<PyAny>> {
58+
let bytes = reverse_complement_bytes::<NucleotideAmbiguous>(dna.as_bytes())?;
59+
Ok(PyBytes::new(py, &bytes).into())
6360
}
6461

6562
/// Get the reverse complement of a bytestring of DNA nucleotides.
@@ -69,14 +66,13 @@ fn _reverse_complement(py: Python, dna: &PyAny) -> PyResult<Py<PyAny>> {
6966
/// * `reverse_complement_strict(b"AAAAAACCC")` returns `b"GGGTTTTTT"`
7067
/// * `reverse_complement_strict(b"AAAAAACCN")` is an error.
7168
#[pyfunction]
72-
fn _reverse_complement_strict(py: Python, dna: &PyAny) -> PyResult<Py<PyAny>> {
73-
let dna_bytes = dna.downcast::<PyBytes>()?;
74-
let bytes = reverse_complement_bytes::<Nucleotide>(dna_bytes.as_bytes())?;
75-
Ok(PyBytes::new(py, &bytes).into_py(py))
69+
fn _reverse_complement_strict(py: Python, dna: &PyBytes) -> PyResult<Py<PyAny>> {
70+
let bytes = reverse_complement_bytes::<Nucleotide>(dna.as_bytes())?;
71+
Ok(PyBytes::new(py, &bytes).into())
7672
}
7773

7874
#[pymodule]
79-
fn quickdna(py: Python, m: &PyModule) -> PyResult<()> {
75+
fn quickdna(_py: Python, m: &PyModule) -> PyResult<()> {
8076
m.add_function(wrap_pyfunction!(_check_table, m)?)?;
8177
m.add_function(wrap_pyfunction!(_translate, m)?)?;
8278
m.add_function(wrap_pyfunction!(_translate_strict, m)?)?;

0 commit comments

Comments
 (0)