
Commit df6679d

fix some review comments
1 parent a7c9562

2 files changed: +32 −12 lines changed

cebra/attribution/jacobian_attribution.py

Lines changed: 22 additions & 6 deletions

@@ -52,14 +52,30 @@ def get_attribution_map(
     convert_to_numpy: bool = True,
     aggregate: Literal["mean", "sum", "max"] = "mean",
     transform: Literal["none", "abs"] = "none",
-    hybrid_solver=False,
+    hybrid_solver: bool = False,
 ):
-    """Estimate attribution maps.
+    """Estimate attribution maps using the Jacobian pseudo-inverse.
+
     The function estimates Jacobian matrices for each point in the model,
-    computes the pseudo-inverse (for every sample), applies the `transform`
-    function point-wise, and then aggregates with the `aggregate` function
-    over the sample dimension.
-    The result is a `(num_inputs, num_features)` attribution map.
+    computes the pseudo-inverse (for every sample) and then aggregates
+    the resulting matrices to compute an attribution map.
+
+    Args:
+        model: The neural network model for which to compute attributions.
+        input_data: Input tensor or numpy array to compute attributions for.
+        double_precision: If ``True``, use double precision for computation.
+        convert_to_numpy: If ``True``, convert the output to numpy arrays.
+        aggregate: Method to aggregate attribution values across samples.
+            Options are ``"mean"``, ``"sum"``, or ``"max"``.
+        transform: Transformation to apply to attribution values.
+            Options are ``"none"`` or ``"abs"``.
+        hybrid_solver: If ``True``, handle multi-objective models differently.
+
+    Returns:
+        A tuple containing:
+        - jf: The Jacobian matrix of shape (num_samples, output_dim, input_dim)
+        - jhatg: The pseudo-inverse of the Jacobian matrix
+        The result is effectively a ``(num_inputs, num_features)`` attribution map.
     """
     assert aggregate in ["mean", "sum", "max"]

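For orientation, here is a minimal usage sketch of the updated signature, assuming the argument order documented in the new docstring. The toy encoder, shapes, and values are illustrative stand-ins, not part of the commit; in practice the model would presumably be a trained CEBRA encoder.

import torch

from cebra.attribution.jacobian_attribution import get_attribution_map

# Hypothetical stand-in encoder: 100 input features -> 8 latent features.
model = torch.nn.Sequential(
    torch.nn.Linear(100, 32),
    torch.nn.GELU(),
    torch.nn.Linear(32, 8),
)
input_data = torch.randn(500, 100)  # (num_samples, num_inputs)

# Per the docstring: jf holds per-sample Jacobians with shape
# (num_samples, output_dim, input_dim), jhatg their pseudo-inverses.
jf, jhatg = get_attribution_map(
    model,
    input_data,
    double_precision=True,
    convert_to_numpy=True,
    aggregate="mean",  # or "sum" / "max"
    transform="abs",  # or "none"
    hybrid_solver=False,  # True only for multi-objective models
)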
cebra/models/multiobjective.py

Lines changed: 10 additions & 6 deletions

@@ -283,13 +283,14 @@ def __init__(self,
 
         if max_slice_dim != self.num_output:
             raise ValueError(
-                f"The dimension of output {self.num_output} is different than the highest dimension of slices {max_slice_dim}."
-                f"They need to have the same dimension.")
+                f"The dimension of output {self.num_output} is different from the highest dimension of the slices ({max_slice_dim}). "
+                "The output dimension and the slice dimension must match."
+            )
 
         check_slices_for_gaps(self.feature_ranges)
 
         if check_overlapping_feature_ranges(self.feature_ranges):
-            print("Computing renormalize ranges...")
+            print("Computing renormalized ranges...")
             self.renormalize_ranges = compute_renormalize_ranges(
                 self.feature_ranges, sort=True)
             print("New ranges:", self.renormalize_ranges)
@@ -327,9 +328,12 @@ def forward(self, inputs):
 
         if self.renormalize:
             if hasattr(self, "renormalize_ranges"):
-                #TODO: does the order of the renormalize ranges matter??
-                # I think it does, imagine that the renormalize ranges are (5, 10), (0, 5), then
-                # when we do torch.cat() output will be wrong --> Renormalize ranges need to be ordered.
+                if not all(self.renormalize_ranges[i].start <=
+                           self.renormalize_ranges[i + 1].start
+                           for i in range(len(self.renormalize_ranges) - 1)):
+                    raise ValueError(
+                        "The renormalize_ranges must be sorted by start index.")
+
                 output = [
                     self._norm(output[:, slice_features])
                     for slice_features in self.renormalize_ranges

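The deleted TODO describes exactly the failure mode the new guard prevents: torch.cat() concatenates slices in list order, so unsorted ranges silently permute the feature axis. A self-contained sketch with toy tensors (not code from the repository) makes this concrete:

import torch

output = torch.arange(10.0).unsqueeze(0)  # shape (1, 10)

sorted_ranges = [slice(0, 5), slice(5, 10)]
unsorted_ranges = [slice(5, 10), slice(0, 5)]  # the (5, 10), (0, 5) case from the TODO

# In list order, the sorted ranges reproduce the original feature layout ...
print(torch.cat([output[:, s] for s in sorted_ranges], dim=1))
# tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])

# ... while the unsorted ranges swap the two feature blocks.
print(torch.cat([output[:, s] for s in unsorted_ranges], dim=1))
# tensor([[5., 6., 7., 8., 9., 0., 1., 2., 3., 4.]])

# The added guard rejects the unsorted case up front, mirroring the new code:
ranges = unsorted_ranges
if not all(ranges[i].start <= ranges[i + 1].start
           for i in range(len(ranges) - 1)):
    raise ValueError("The renormalize_ranges must be sorted by start index.")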