Skip to content

Commit 821c8b0

Browse files
committed
Merge branch 'dev' into production
2 parents 1faf578 + f5f20fb commit 821c8b0

File tree

14 files changed

+556
-49
lines changed

14 files changed

+556
-49
lines changed

.dependabot/config.yml

Lines changed: 0 additions & 6 deletions
This file was deleted.

.github/dependabot.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
version: 2
2+
updates:
3+
- package-ecosystem: pip
4+
directory: "/"
5+
schedule:
6+
interval: daily
7+
open-pull-requests-limit: 10
8+
target-branch: dev
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2+
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3+
4+
name: Python package
5+
6+
on:
7+
push:
8+
branches: [ dev ]
9+
pull_request:
10+
branches: [ dev ]
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
strategy:
17+
matrix:
18+
python-version: [3.8, 3.9]
19+
requirements-cmd:
20+
- -r testing-requirements.txt -r requirements.txt
21+
pytest-run-cmd:
22+
- pytest
23+
exclude:
24+
- python-version: 3.9
25+
requirements-cmd: -r testing-requirements.txt -r requirements.txt
26+
pytest-run-cmd: pytest
27+
include:
28+
- python-version: 3.8
29+
requirements-cmd: -r testing-requirements.txt -r oldest
30+
- python-version: 3.9
31+
requirements-cmd: -r testing-requirements.txt -r coverage-requirements.txt -r requirements.txt
32+
pytest-run-cmd: |
33+
pytest --cov=generalizedtrees --cov-report xml
34+
codecov
35+
36+
steps:
37+
- uses: actions/checkout@v2
38+
- name: Set up Python ${{ matrix.python-version }}
39+
uses: actions/setup-python@v2
40+
with:
41+
python-version: ${{ matrix.python-version }}
42+
- name: Install dependencies
43+
run: |
44+
python -m pip install --upgrade pip
45+
python -m pip install flake8 pytest
46+
python -m pip install ${{ matrix.requirements-cmd }}
47+
- name: Lint with flake8
48+
run: |
49+
# stop the build if there are Python syntax errors or undefined names
50+
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
51+
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
52+
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
53+
- name: Test with pytest
54+
run: |
55+
${{ matrix.pytest-run-cmd }}

.travis.yml

Lines changed: 0 additions & 20 deletions
This file was deleted.

coverage-requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pytest-cov==2.10.1
1+
pytest-cov==2.11.1
22
codecov==2.1.11

generalizedtrees/constraints.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class MofN(Constraint):
122122
class SearchOperator(Flag):
123123
INC_M = auto()
124124
INC_N = auto()
125+
DEC_M = auto()
126+
DEC_N = auto()
125127
INC_NM = INC_N | INC_M
126128

127129
def __init__(self, m: int, constraints):
@@ -208,15 +210,19 @@ def neighboring_tests(
208210

209211
for operator in search_operators:
210212

211-
new_m = constraint.m_to_satisfy + \
212-
(1 if operator & MofN.SearchOperator.INC_M else 0)
213+
new_m = constraint.m_to_satisfy + (
214+
1 if operator & MofN.SearchOperator.INC_M else (
215+
-1 if operator & MofN.SearchOperator.DEC_M else 0))
213216

214217
if operator & MofN.SearchOperator.INC_N:
215218
for atom in constraint_candidates:
216219
new_atoms = constraint.constraints + (atom,)
217-
218220
yield MofN(new_m, new_atoms)
219-
221+
222+
elif operator & MofN.SearchOperator.DEC_N:
223+
atoms = list(constraint.constraints)
224+
for atom in atoms:
225+
yield MofN(new_m, tuple(atoms - {atom}))
220226
else:
221227
yield MofN(new_m, constraint.constraints)
222228

generalizedtrees/givens.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,25 @@ class DataWithOracleGivensLC(GivensLC):
110110
a predictor oracle.
111111
"""
112112

113-
def process(self, data, oracle, *args, target_names=None, prelabel_data=True, **kwargs):
113+
def process(self, data, oracle, *args, target_names=None, prelabel_data=True, feature_groups=None, **kwargs):
114+
"""
115+
Processing input for the explanation setting:
116+
117+
:param: data - the n-by-d feature matrix
118+
:param: oracle - the oracle, can either be a SKLearn-like classifier, in which case the predict function is
119+
used, or a function that takes data and outputs a prediction (in matrix form).
120+
:param: target_names - The names of the target classes/components.
121+
:param: prelabel_data - Whether to run the oracle on the data first, before building the model
122+
:param: feature_groups - Used in some explanation configurations - A list of lists of feature indeces
123+
representing semantically meaningful feature groups
124+
"""
114125

115126
if target_names is not None:
116127
self.target_names = target_names
128+
129+
if feature_groups is not None:
130+
# TODO: Validation
131+
self.feature_groups = feature_groups
117132

118133
# Parse data
119134
self.data_matrix, self.feature_names, self.feature_spec = parse_data(
@@ -127,7 +142,7 @@ def process(self, data, oracle, *args, target_names=None, prelabel_data=True, **
127142
logger.info(
128143
'Inferring that oracle is a Scikit-Learn-like classifier '
129144
'and using the "predict" method.')
130-
self.oracle = lambda x: np.eye(len(oracle.classes_))[oracle.predict(x),]
145+
self.oracle = lambda x: np.eye(len(oracle.classes_))[oracle.predict(x).astype('intp'),]
131146
else:
132147
logger.info('Treating oracle as a function')
133148
self.oracle = oracle

generalizedtrees/grow.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,10 @@ class GreedyBuilderLC:
192192

193193
def __init__(self):
194194
self.splitter = DefaultSplitConstructorLC()
195+
196+
def initialize(self, givens: GivensLC) -> None:
197+
self.node_builder.initialize(givens)
198+
self.splitter.initialize(givens)
195199

196200
def build_tree(self) -> Tree:
197201

generalizedtrees/learn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,7 @@ def fit(self, *args, **kwargs):
105105

106106
# Set components
107107
self.predictor.initialize(self.givens)
108-
self.builder.node_builder.initialize(self.givens)
109-
self.split_generator.initialize(self.givens)
108+
self.builder.initialize(self.givens)
110109

111110
# Build tree
112111
self.tree = self.builder.build_tree()

0 commit comments

Comments
 (0)