@@ -239,35 +239,59 @@ def sample_polytope(
239239 Returns:
240240 (n, d) dim Tensor containing the resulting samples.
241241 """
242+ # Check that starting point satisfies the constraints.
243+ if not ((slack := A @ x0 - b ) <= 0 ).all ():
244+ raise ValueError (
245+ f"Starting point does not satisfy the constraints. Inputs: { A = } ,"
246+ f"{ b = } , { x0 = } , A@x0-b={ slack } ."
247+ )
248+ # Remove rows where all elements of A are 0. This avoids nan and infs later.
249+ # A may have zero rows in it when this is called from PolytopeSampler
250+ # with equality constraints (which are absorbed into A & b).
251+ non_zero_rows = torch .any (A != 0 , dim = - 1 )
252+ A = A [non_zero_rows ]
253+ b = b [non_zero_rows ]
254+
242255 n_tot = n + n0
243256 seed = seed if seed is not None else torch .randint (0 , 1000000 , (1 ,)).item ()
244257 with manual_seed (seed = seed ):
245258 rands = torch .rand (n_tot , dtype = A .dtype , device = A .device )
246259
247- # pre-sample samples from hypersphere
248- d = x0 .size (0 )
249- # uniform samples from unit ball in d dims
250- # increment seed by +1 to avoid correlation with step size, see #2156 for details
260+ # Sample uniformly from unit hypersphere in d dims.
261+ # Increment seed by +1 to avoid correlation with step size, see #2156 for details.
251262 Rs = sample_hypersphere (
252- d = d , n = n_tot , dtype = A .dtype , device = A .device , seed = seed + 1
263+ d = x0 . shape [ 0 ] , n = n_tot , dtype = A .dtype , device = A .device , seed = seed + 1
253264 ).unsqueeze (- 1 )
254265
255- # compute matprods in batch
266+ # Use batch operations for matrix multiplication.
256267 ARs = (A @ Rs ).squeeze (- 1 )
257268 out = torch .empty (n , A .size (- 1 ), dtype = A .dtype , device = A .device )
258269 x = x0 .clone ()
270+ large_constant = torch .finfo ().max
259271 for i , (ar , r , rnd ) in enumerate (zip (ARs , Rs , rands )):
260- # given x, the next point in the chain is x+alpha*r
261- # it also satisfies A(x+alpha*r)<=b which implies A*alpha*r<=b-Ax
272+ # Given x, the next point in the chain is x+alpha*r.
273+ # It must satisfy A(x+alpha*r)<=b, which implies A*alpha*r<=b-Ax,
262274 # so alpha<=(b-Ax)/ar for ar>0, and alpha>=(b-Ax)/ar for ar<0.
263- # b - A @ x is always >= 0, clamping for numerical tolerances
275+ # If x is at the boundary, b - Ax = 0. If ar > 0, then we must
276+ # have alpha <= 0. If ar < 0, we must have alpha >= 0.
277+ # ar == 0 is an unlikely event that provides no signal.
278+ # b - A @ x is always >= 0, clamping for numerical tolerances.
264279 w = (b - A @ x ).squeeze ().clamp (min = 0.0 ) / ar
265- pos = w >= 0
266- alpha_max = w [pos ].min ()
267- # important to include equality here in cases x is at the boundary
268- # of the polytope
269- neg = w <= 0
270- alpha_min = w [neg ].max ()
280+ # Find upper bound for alpha. If there are no constraints on
281+ # the upper bound of alpha, set it to a large value.
282+ pos = w > 0
283+ alpha_max = w [pos ].min ().item () if pos .any () else large_constant
284+ # Find lower bound for alpha.
285+ neg = w < 0
286+ alpha_min = w [neg ].max ().item () if neg .any () else - large_constant
287+ # Handle the boundary case.
288+ if (w_eq_0 := (w == 0 )).any ():
289+ # If ar > 0 at the boundary, alpha <= 0.
290+ if w_eq_0 .logical_and (ar > 0 ).any ():
291+ alpha_max = min (alpha_max , 0.0 )
292+ # If ar < 0 at the boundary, alpha >= 0.
293+ if w_eq_0 .logical_and (ar < 0 ).any ():
294+ alpha_min = max (alpha_min , 0.0 )
271295 # alpha~Unif[alpha_min, alpha_max]
272296 alpha = alpha_min + rnd * (alpha_max - alpha_min )
273297 x = x + alpha * r
0 commit comments