Skip to content

Commit 66f8b48

Browse files
author
dPys
committed
Add doctest to non_overlapping_qspace_samples, made general-purpose
1 parent f5ca8c1 commit 66f8b48

File tree

1 file changed

+61
-10
lines changed

1 file changed

+61
-10
lines changed

dmriprep/utils/vectors.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -377,26 +377,70 @@ def bvecs2ras(affine, bvecs, norm=True, bvec_norm_epsilon=0.2):
377377
return rotated_bvecs
378378

379379

380-
def _nonoverlapping_qspace_samples(
381-
prediction_bval, prediction_bvec, all_bvals, all_bvecs, cutoff
382-
):
380+
def nonoverlapping_qspace_samples(sample_bval, sample_bvec, all_bvals,
381+
all_bvecs, cutoff=2):
383382
"""Ensure that none of the training samples are too close to the sample to predict.
384383
Parameters
384+
385+
Parameters
386+
----------
387+
sample_bval : int
388+
A single b-value sampled along the sphere.
389+
sample_bvec : int
390+
A single b-vector sampled along the sphere.
391+
Should correspond to `sample_bval`.
392+
all_bvals : ndarray
393+
A 1D vector of all b-values from the diffusion series.
394+
all_bvecs: ndarray
395+
A 3 x n vector of all vectors from the diffusion series,
396+
where n is the total number of samples.
397+
cutoff : float
398+
A minimal allowable q-space distance between points on
399+
the sphere.
400+
401+
Returns
402+
-------
403+
ok_samples : boolean ndarray
404+
True for q-vectors whose spatial distribution along
405+
the sphere is non-overlapping, else False.
406+
407+
Examples
408+
--------
409+
>>> bvec1 = np.array([1, 0, 0])
410+
>>> bvec2 = np.array([1, 0, 0])
411+
>>> bvec3 = np.array([0, 1, 0])
412+
>>> bval1 = 1000
413+
>>> bval2 = 1000
414+
>>> bval3 = 1000
415+
>>> all_bvals = np.array([0, bval2, bval3])
416+
>>> all_bvecs = np.array([np.zeros(3), bvec2, bvec3])
417+
>>> # Case 1: overlapping
418+
>>> nonoverlapping_qspace_samples(bval1, bvec1, all_bvals, all_bvecs, cutoff=2)
419+
array([ True, False, True])
420+
>>> all_bvals = np.array([0, bval1, bval2])
421+
>>> all_bvecs = np.array([np.zeros(3), bvec1, bvec2])
422+
>>> # Case 2: non-overlapping
423+
>>> nonoverlapping_qspace_samples(bval3, bvec3, all_bvals, all_bvecs, cutoff=2)
424+
array([ True, True, True])
385425
"""
386-
min_bval = min(min(all_bvals), prediction_bval)
426+
min_bval = min(min(all_bvals), sample_bval)
427+
max_bval = max(max(all_bvals), sample_bval)
428+
if min_bval == max_bval:
429+
raise ValueError('All b-values are identical')
430+
387431
all_qvals = np.sqrt(all_bvals - min_bval)
388-
prediction_qval = np.sqrt(prediction_bval - min_bval)
432+
sample_qval = np.sqrt(sample_bval - min_bval)
389433

390434
# Convert q values to percent of maximum qval
391-
max_qval = max(max(all_qvals), prediction_qval)
435+
max_qval = max(max(all_qvals), sample_qval)
392436
all_qvals_scaled = all_qvals / max_qval * 100
393437
scaled_qvecs = all_bvecs * all_qvals_scaled[:, np.newaxis]
394-
scaled_prediction_qvec = prediction_bvec * (prediction_qval / max_qval * 100)
438+
scaled_sample_qvec = sample_bvec * (sample_qval / max_qval * 100)
395439

396-
# Calculate the distance between the sampled qvecs and the prediction qvec
440+
# Calculate the distance between all qvecs and the sample qvec
397441
ok_samples = (
398-
np.linalg.norm(scaled_qvecs - scaled_prediction_qvec, axis=1) > cutoff
399-
) * (np.linalg.norm(scaled_qvecs + scaled_prediction_qvec, axis=1) > cutoff)
442+
np.linalg.norm(scaled_qvecs - scaled_sample_qvec, axis=1) > cutoff
443+
) * (np.linalg.norm(scaled_qvecs + scaled_sample_qvec, axis=1) > cutoff)
400444

401445
return ok_samples
402446

@@ -409,6 +453,9 @@ def _rasb_to_bvec_list(in_rasb):
409453
----------
410454
in_rasb : str or os.pathlike
411455
File path to a RAS-B gradient table.
456+
Returns
457+
-------
458+
List of b-vectors as floats.
412459
"""
413460
import numpy as np
414461

@@ -425,6 +472,10 @@ def _rasb_to_bval_floats(in_rasb):
425472
----------
426473
in_rasb : str or os.pathlike
427474
File path to a RAS-B gradient table.
475+
476+
Returns
477+
-------
478+
List of b-values as floats.
428479
"""
429480
import numpy as np
430481

0 commit comments

Comments
 (0)