</p>

<p align="center" style="text-align:center;">
  <a href="https://pypi.org/project/pydvl/"><img src="https://img.shields.io/pypi/v/pydvl.svg" alt="PyPI"></a>
  <a href="https://pypi.org/project/pydvl/"><img src="https://img.shields.io/pypi/pyversions/pydvl.svg" alt="Version"></a>
  <a href="https://pydvl.org"><img src="https://img.shields.io/badge/docs-All%20versions-009485" alt="documentation"></a>
  <a href="https://raw.githubusercontent.com/aai-institute/pyDVL/master/LICENSE"><img alt="License" src="https://img.shields.io/pypi/l/pydvl"></a>
  <a href="https://github.com/aai-institute/pyDVL/actions/workflows/main.yaml"><img src="https://github.com/aai-institute/pyDVL/actions/workflows/main.yaml/badge.svg" alt="Build status"></a>
  <a href="https://codecov.io/gh/aai-institute/pyDVL"><img src="https://codecov.io/gh/aai-institute/pyDVL/graph/badge.svg?token=VN7DNDE0FV"/></a>
  <a href="https://zenodo.org/badge/latestdoi/354117916"><img src="https://zenodo.org/badge/354117916.svg" alt="DOI"></a>
</p>

**pyDVL** collects algorithms for **Data Valuation** and **Influence Function** computation.

For influence computation, follow these steps:

1. Import the required packages.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pydvl.influence.torch import DirectInfluence
from pydvl.influence.torch.util import NestedTorchCatAggregator, TorchNumpyConverter
from pydvl.influence import SequentialInfluenceCalculator
```

2. Create PyTorch data loaders for your train and test splits.

```python
input_dim = (5, 5, 5)
output_dim = 3
train_x = torch.rand((10, *input_dim))
train_y = torch.rand((10, output_dim))
test_x = torch.rand((5, *input_dim))
test_y = torch.rand((5, output_dim))

train_data_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=2)
test_data_loader = DataLoader(TensorDataset(test_x, test_y), batch_size=1)
```

3. Instantiate your neural network model.

```python
nn_architecture = nn.Sequential(
    nn.Conv2d(in_channels=5, out_channels=3, kernel_size=3),
    nn.Flatten(),
    nn.Linear(27, 3),
)
```

4. Define your loss:

```python
loss = nn.MSELoss()
```

5. Instantiate an `InfluenceFunctionModel` and fit it to the training data.

```python
infl_model = DirectInfluence(nn_architecture, loss, hessian_regularization=0.01)
infl_model = infl_model.fit(train_data_loader)
```

6. For small amounts of input data, call the `influences` method on the fitted instance.

```python
influences = infl_model.influences(test_x, test_y, train_x, train_y)
```

The result is a tensor of shape `(training samples x test samples)` that contains
at index `(i, j)` the influence of training sample `i` on test sample `j`.

7. For larger amounts of data, wrap the fitted model in a `SequentialInfluenceCalculator` and call its methods on the data loaders.

```python
infl_calc = SequentialInfluenceCalculator(infl_model)

# Lazy object providing arrays batch-wise in a sequential manner
lazy_influences = infl_calc.influences(test_data_loader, train_data_loader)

# Trigger computation and pull results to memory
influences = lazy_influences.compute(aggregator=NestedTorchCatAggregator())

# Trigger computation and write results batch-wise to disk
lazy_influences.to_zarr("influences_result", TorchNumpyConverter())
```
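
As a hypothetical follow-up, the values persisted by `to_zarr` can be inspected
later with the `zarr` package. This is only a minimal sketch: it assumes `zarr`
is installed and that `"influences_result"` (the path used above) is a readable
zarr store; the exact on-disk layout is determined by pyDVL's writer.

```python
import zarr

# Open the store written by `to_zarr` above in read-only mode and print a
# summary. Whether this is a zarr array or group depends on how pyDVL lays
# out the results on disk.
stored = zarr.open("influences_result", mode="r")
print(stored)
```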

The higher the absolute value of the influence of a training sample on a test
sample, the more influential it is for the chosen test sample, model and data
loaders. The sign of the influence indicates whether the training sample is
beneficial or detrimental to the model's performance on that test sample.
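
As a quick, hypothetical illustration (assuming the `(training samples x test
samples)` layout described above and the `influences` tensor from step 6), the
most influential training points for a given test sample can be found by
sorting one column of the result:

```python
# Rank training samples by the magnitude of their influence on the first
# test sample (column 0), largest first.
ranking = torch.argsort(influences[:, 0].abs(), descending=True)
print(ranking[:3])  # indices of the three most influential training samples
```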

# License

pyDVL is distributed under