diff --git a/torchquad/integration/grid_integrator.py b/torchquad/integration/grid_integrator.py index 2243399d..72350371 100644 --- a/torchquad/integration/grid_integrator.py +++ b/torchquad/integration/grid_integrator.py @@ -28,7 +28,7 @@ def f(integration_domain, N, requires_grad=False, backend=None): def _weights(self, N, dim, backend, requires_grad=False): return None - def integrate(self, fn, dim, N, integration_domain, backend): + def integrate(self, fn, dim, N, integration_domain, backend, args=None): """Integrate the passed function on the passed domain using a Composite Newton Cotes rule. The argument meanings are explained in the sub-classes. @@ -47,7 +47,7 @@ def integrate(self, fn, dim, N, integration_domain, backend): logger.debug("Evaluating integrand on the grid.") function_values, num_points = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) self._nr_of_fevals = num_points @@ -139,7 +139,7 @@ def _adjust_N(dim, N): return N def get_jit_compiled_integrate( - self, dim, N=None, integration_domain=None, backend=None + self, dim, N=None, integration_domain=None, backend=None, args=None ): """Create an integrate function where the performance-relevant steps except the integrand evaluation are JIT compiled. Use this method only if the integrand cannot be compiled. @@ -151,6 +151,7 @@ def get_jit_compiled_integrate( N (int, optional): Total number of sample points to use for the integration. See the integrate method documentation for more details. integration_domain (list or backend tensor, optional): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim. It can also determine the numerical backend. backend (string, optional): Numerical backend. Defaults to integration_domain's backend if it is a tensor and otherwise to the backend from the latest call to set_up_backend or "torch" for backwards compatibility. + args (list or tuple, optional): Any arguments required by the function. Defaults to None. Returns: function(fn, integration_domain): JIT compiled integrate function where all parameters except the integrand and domain are fixed @@ -197,7 +198,7 @@ def get_jit_compiled_integrate( def compiled_integrate(fn, integration_domain): grid_points, hs, n_per_dim = jit_calculate_grid(N, integration_domain) function_values, _ = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) return jit_calculate_result( function_values, dim, int(n_per_dim), hs, integration_domain @@ -238,6 +239,7 @@ def step3(function_values, hs, integration_domain): example_integrand, grid_points, weights=self._weights(n_per_dim, dim, backend), + args=args, ) # Trace the third step @@ -257,7 +259,7 @@ def step3(function_values, hs, integration_domain): def compiled_integrate(fn, integration_domain): grid_points, hs, _ = step1(integration_domain) function_values, _ = self.evaluate_integrand( - fn, grid_points, weights=self._weights(n_per_dim, dim, backend) + fn, grid_points, weights=self._weights(n_per_dim, dim, backend), args=args ) result = step3(function_values, hs, integration_domain) return result