Skip to content

Commit ce1cd6d

Browse files
Support pandas' iloc indexer (huggingface#191)
1 parent 77f656c commit ce1cd6d

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

src/smolagents/local_python_executor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,9 @@ def evaluate_subscript(
685685
if isinstance(value, pd.core.indexing._LocIndexer):
686686
parent_object = value.obj
687687
return parent_object.loc[index]
688+
if isinstance(value, pd.core.indexing._iLocIndexer):
689+
parent_object = value.obj
690+
return parent_object.iloc[index]
688691
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
689692
return value[index]
690693
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):

tests/test_python_interpreter.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ def test_pandas(self):
808808
)
809809
assert np.array_equal(result.values[0], [104, 1])
810810

811+
# Test groupby
811812
code = """import pandas as pd
812813
data = pd.DataFrame.from_dict([
813814
{"Pclass": 1, "Survived": 1},
@@ -821,6 +822,21 @@ def test_pandas(self):
821822
)
822823
assert result.values[1] == 0.5
823824

825+
# Test loc and iloc
826+
code = """import pandas as pd
827+
data = pd.DataFrame.from_dict([
828+
{"Pclass": 1, "Survived": 1},
829+
{"Pclass": 2, "Survived": 0},
830+
{"Pclass": 2, "Survived": 1}
831+
])
832+
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
833+
survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
834+
survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
835+
"""
836+
result, _ = evaluate_python_code(
837+
code, {}, state={}, authorized_imports=["pandas"]
838+
)
839+
824840
def test_starred(self):
825841
code = """
826842
from math import radians, sin, cos, sqrt, atan2

0 commit comments

Comments
 (0)