Skip to content

Commit 3e1bca5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 6fdec3c commit 3e1bca5

File tree

3 files changed

+36
-27
lines changed

3 files changed

+36
-27
lines changed

light-curve/src/arrow_input.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ pub(crate) struct ArrowLcsSchema {
4444
pub list_type: ArrowListType,
4545
}
4646

47-
fn resolve_field(fields: &arrow_schema::Fields, field_ref: &ArrowFieldRef, role: &str) -> Res<usize> {
47+
fn resolve_field(
48+
fields: &arrow_schema::Fields,
49+
field_ref: &ArrowFieldRef,
50+
role: &str,
51+
) -> Res<usize> {
4852
match field_ref {
4953
ArrowFieldRef::Index(i) => {
5054
if *i >= fields.len() {
@@ -56,15 +60,14 @@ fn resolve_field(fields: &arrow_schema::Fields, field_ref: &ArrowFieldRef, role:
5660
Ok(*i)
5761
}
5862
}
59-
ArrowFieldRef::Name(name) => fields
60-
.iter()
61-
.position(|f| f.name() == name)
62-
.ok_or_else(|| {
63+
ArrowFieldRef::Name(name) => {
64+
fields.iter().position(|f| f.name() == name).ok_or_else(|| {
6365
Exception::ValueError(format!(
6466
"arrow_fields: field name {name:?} for {role} not found in struct fields {:?}",
6567
fields.iter().map(|f| f.name()).collect::<Vec<_>>()
6668
))
67-
}),
69+
})
70+
}
6871
}
6972
}
7073

light-curve/src/features.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,11 @@ impl PyFeatureEvaluator {
623623
"Null entries in the struct array are not supported".to_string(),
624624
));
625625
}
626-
for &col_idx in &[Some(t_idx), Some(m_idx), sigma_idx].into_iter().flatten().collect::<Vec<_>>() {
626+
for &col_idx in &[Some(t_idx), Some(m_idx), sigma_idx]
627+
.into_iter()
628+
.flatten()
629+
.collect::<Vec<_>>()
630+
{
627631
if struct_arr.column(col_idx).null_count() > 0 {
628632
return Err(Exception::NotImplementedError(
629633
"Null values in data columns are not supported".to_string(),
@@ -677,9 +681,12 @@ impl PyFeatureEvaluator {
677681
}
678682
}
679683

680-
681-
fn parse_arrow_fields(arrow_fields: Option<&Bound<pyo3::types::PyList>>) -> Res<Option<Vec<ArrowFieldRef>>> {
682-
let Some(list) = arrow_fields else { return Ok(None) };
684+
fn parse_arrow_fields(
685+
arrow_fields: Option<&Bound<pyo3::types::PyList>>,
686+
) -> Res<Option<Vec<ArrowFieldRef>>> {
687+
let Some(list) = arrow_fields else {
688+
return Ok(None);
689+
};
683690
let refs = list
684691
.iter()
685692
.map(|item| {
@@ -689,7 +696,8 @@ fn parse_arrow_fields(arrow_fields: Option<&Bound<pyo3::types::PyList>>) -> Res<
689696
Ok(ArrowFieldRef::Name(name))
690697
} else {
691698
Err(Exception::TypeError(
692-
"arrow_fields elements must be integers (indices) or strings (field names)".to_string(),
699+
"arrow_fields elements must be integers (indices) or strings (field names)"
700+
.to_string(),
693701
))
694702
}
695703
})
@@ -804,7 +812,15 @@ impl PyFeatureEvaluator {
804812
// Try Arrow path first
805813
if lcs.hasattr("__arrow_c_array__")? || lcs.hasattr("__arrow_c_stream__")? {
806814
let parsed = parse_arrow_fields(arrow_fields.as_ref())?;
807-
return self.many_arrow(py, &lcs, fill_value, sorted, check, n_jobs, parsed.as_deref());
815+
return self.many_arrow(
816+
py,
817+
&lcs,
818+
fill_value,
819+
sorted,
820+
check,
821+
n_jobs,
822+
parsed.as_deref(),
823+
);
808824
}
809825
// Fall back to list-of-tuples path
810826
let lcs: PyLcs<'py> = lcs.extract()?;

light-curve/tests/light_curve_ext/test_feature.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -855,29 +855,23 @@ def test_many_arrow_fields_nondefault_order():
855855
def test_many_arrow_fields_invalid_name():
856856
"""arrow_fields with a nonexistent field name raises ValueError."""
857857
feature = lc.Amplitude()
858-
arr = _make_arrow_lcs(
859-
[gen_lc(10, rng=np.random.default_rng(5)) for _ in range(2)]
860-
)
858+
arr = _make_arrow_lcs([gen_lc(10, rng=np.random.default_rng(5)) for _ in range(2)])
861859
with pytest.raises(ValueError, match="not found"):
862860
feature.many(arr, sorted=True, arrow_fields=["t", "nonexistent"])
863861

864862

865863
def test_many_arrow_fields_index_out_of_range():
866864
"""arrow_fields with an out-of-range index raises ValueError."""
867865
feature = lc.Amplitude()
868-
arr = _make_arrow_lcs(
869-
[gen_lc(10, rng=np.random.default_rng(6)) for _ in range(2)]
870-
)
866+
arr = _make_arrow_lcs([gen_lc(10, rng=np.random.default_rng(6)) for _ in range(2)])
871867
with pytest.raises(ValueError, match="out of range"):
872868
feature.many(arr, sorted=True, arrow_fields=[0, 99])
873869

874870

875871
def test_many_arrow_fields_wrong_count():
876872
"""arrow_fields with 1 or 4 elements raises ValueError."""
877873
feature = lc.Amplitude()
878-
arr = _make_arrow_lcs(
879-
[gen_lc(10, rng=np.random.default_rng(7)) for _ in range(2)]
880-
)
874+
arr = _make_arrow_lcs([gen_lc(10, rng=np.random.default_rng(7)) for _ in range(2)])
881875
with pytest.raises(ValueError, match="2 .* or 3"):
882876
feature.many(arr, sorted=True, arrow_fields=["t"])
883877
with pytest.raises(ValueError, match="2 .* or 3"):
@@ -887,18 +881,14 @@ def test_many_arrow_fields_wrong_count():
887881
def test_many_arrow_fields_duplicate_fields():
888882
"""arrow_fields pointing to the same field twice raises ValueError."""
889883
feature = lc.Amplitude()
890-
arr = _make_arrow_lcs(
891-
[gen_lc(10, rng=np.random.default_rng(8)) for _ in range(2)]
892-
)
884+
arr = _make_arrow_lcs([gen_lc(10, rng=np.random.default_rng(8)) for _ in range(2)])
893885
with pytest.raises(ValueError, match="different fields"):
894886
feature.many(arr, sorted=True, arrow_fields=["t", "t"])
895887

896888

897889
def test_many_arrow_fields_wrong_element_type():
898890
"""Non-str, non-int element in arrow_fields raises TypeError."""
899891
feature = lc.Amplitude()
900-
arr = _make_arrow_lcs(
901-
[gen_lc(10, rng=np.random.default_rng(9)) for _ in range(2)]
902-
)
892+
arr = _make_arrow_lcs([gen_lc(10, rng=np.random.default_rng(9)) for _ in range(2)])
903893
with pytest.raises(TypeError, match="integers.*strings|strings.*integers"):
904894
feature.many(arr, sorted=True, arrow_fields=["t", 1.5])

0 commit comments

Comments
 (0)