Skip to content

Commit 2e19d94

Browse files
mmoesmmboehm7
authored andcommitted
[SYSTEMDS-3182] Builtin ampute() for introducing missing values in data
Closes #2250.
1 parent 23c7e3b commit 2e19d94

File tree

5 files changed

+580
-0
lines changed

5 files changed

+580
-0
lines changed

scripts/builtin/ampute.dml

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
#-------------------------------------------------------------
2+
#
3+
# Licensed to the Apache Software Foundation (ASF) under one
4+
# or more contributor license agreements. See the NOTICE file
5+
# distributed with this work for additional information
6+
# regarding copyright ownership. The ASF licenses this file
7+
# to you under the Apache License, Version 2.0 (the
8+
# "License"); you may not use this file except in compliance
9+
# with the License. You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing,
14+
# software distributed under the License is distributed on an
15+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16+
# KIND, either express or implied. See the License for the
17+
# specific language governing permissions and limitations
18+
# under the License.
19+
#
20+
#-------------------------------------------------------------
21+
22+
# This function injects missing values into a given multivariate dataset, similarly to the ampute() method in R's MICE package.
23+
#
24+
# INPUT:
25+
# -------------------------------------------------------------------------------------
26+
# X a multivariate numeric dataset [shape: n-by-m]
27+
# prop a number in the (0, 1] range specifying the proportion of amputed rows across the entire dataset
28+
# patterns a pattern matrix of 0's and 1's [shape: k-by-m] where each row corresponds to a pattern. 0 indicates that a variable should have missing values and 1 indicates that a variable should remain complete
29+
# freq a vector [length: k] containing the relative frequency with which each pattern in the patterns matrix should occur
30+
# mech a string [either "MAR", "MNAR", or "MCAR"] specifying the missingness mechanism. Chosen "MAR" and "MNAR" settings will be overridden if a non-default weight matrix is specified
31+
# weights a weight matrix [shape: k-by-m], containing weights that will be used to calculate the weighted sum scores. Will be overridden if mech == "MCAR"
32+
# seed a manually defined seed for reproducible RNG
33+
34+
# -------------------------------------------------------------------------------------
35+
#
36+
# OUTPUT:
37+
# -------------------------------------------------------------------------------------
38+
# amputedX amputed output dataset
39+
# -------------------------------------------------------------------------------------
40+
41+
# Entry point of the ampute builtin.
# Validates the arguments, assigns every row of X to one missingness pattern,
# amputes each pattern group independently, and returns the rows in their
# original order with NaN written at the amputed cells.
m_ampute = function(Matrix[Double] X,
    Double prop = 0.5,
    Matrix[Double] patterns = matrix(0, 0, 0),
    Matrix[Double] freq = matrix(0, 0, 0),
    String mech = "MAR",
    Matrix[Double] weights = matrix(0, 0, 0),
    Integer seed = -1) return(Matrix[Double] amputedX) {
  # Step 1: check the arguments; empty freq/patterns/weights come back filled with defaults.
  # NOTE(review): the original marked this call with "FIX ME" and kept a commented-out
  # variant that assigned to fresh names -- confirm whether reusing the names is intended.
  [freq, patterns, weights] = u_validateInputs(X, prop, freq, patterns, mech, weights)

  nRows = nrow(X)
  nCols = ncol(X)
  nPatterns = nrow(patterns)
  idxCol = nCols + 1  # Extra trailing column that carries each row's original position.

  # Draw a pattern id for every row according to freq.
  [groupOfRow, groupSizes] = u_randomChoice(nRows, freq, seed)

  # Output buffer: data columns plus the original-index column.
  amputedX = matrix(0, rows=nRows, cols=idxCol)

  parfor (p in 1:nPatterns, check=0) {
    nMembers = as.scalar(groupSizes[p])
    if (nMembers > 0) {
      # Step 2: rows belonging to pattern p and their positions in X.
      [members, origIdx] = u_getGroupSamples(X, groupOfRow, nRows, nMembers, nCols, p)

      # Step 3: weighted sum score per row -> amputation probability per row.
      scores = members %*% t(weights[p])
      pAmpute = u_getProbs(scores, nMembers, prop)

      # Step 4: a row is amputed when its uniform draw is at most its probability;
      # (1 - pattern) then restricts the 1's to the columns the pattern marks as missing.
      draws = rand(rows=nMembers, cols=1, min=0, max=1, pdf="uniform", seed=seed)
      nanMask = (draws <= pAmpute) * (1 - patterns[p])
      while (FALSE) {} # Kept from the original ("FIX ME"); presumably forces a statement-block cut -- confirm before removing.
      # Adding NaN turns the selected cells into NaN; adding 0 leaves the others unchanged.
      members = members + replace(target=nanMask, pattern=1, replacement=NaN)

      # Step 5: write the group into its contiguous slice of the output buffer.
      [rowFrom, rowTo] = u_getBounds(groupSizes, nMembers, p)
      amputedX[rowFrom:rowTo, ] = cbind(members, origIdx)
    }
    else {
      print("ampute warning: Zero rows assigned to pattern " + p + ". Consider increasing input data size or pattern frequency?")
    }
  }

  # Step 6: restore the original row order, then drop the helper index column.
  amputedX = order(target=amputedX, by=idxCol)
  amputedX = amputedX[, 1:nCols]
}
89+
90+
# Validates all ampute() arguments and fills in defaults for empty freq, patterns
# and weights matrices (via u_handleDefaults).
#
# X         input dataset
# prop      requested proportion of amputed rows, must lie in (0, 1]
# freq      relative pattern frequencies (row or column vector, or empty)
# patterns  0/1 pattern matrix (or empty)
# mech      "MAR", "MNAR" or "MCAR"
# weights   weight matrix (or empty)
#
# Returns freq (as a column vector), patterns and weights with defaults applied.
# All violations are collected and reported together through a single stop().
u_validateInputs = function(Matrix[Double] X, Double prop, Matrix[Double] freq, Matrix[Double] patterns, String mech, Matrix[Double] weights)
  return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {

  errors = list()
  # Evaluated before defaults are applied, so it reflects what the caller passed in.
  weightsProvided = !u_isEmpty(weights) # FIX ME

  # About the input dataset:
  if (max(is.na(X)) == 1) {
    errors = append(errors, "Input dataset cannot contain any NaN values.")
  }
  if (ncol(X) < 2) {
    errors = append(errors, "Input dataset must contain at least two columns. Only contained " + ncol(X) + ". Missingness patterns require multiple variables to be properly generated.")
  }

  # About mech:
  if (mech != "MAR" & mech != "MCAR" & mech != "MNAR") {
    errors = append(errors, "Invalid option provided for mech: " + mech + ".")
  }
  else if (weightsProvided & mech == "MCAR") {
    print("ampute warning: User-provided weights will be ignored when mechanism MCAR is chosen.")
  }

  # About prop:
  if (!(0 < prop & prop <= 1)) {
    errors = append(errors, "Value of prop must be within the range of (0, 1]. Was " + prop + ".")
  }

  # Set defaults for empty freq, patterns and weights matrices:
  numFeatures = ncol(X)
  [freq, patterns, weights] = u_handleDefaults(freq, patterns, weights, mech, numFeatures)

  # About freq:
  if (nrow(freq) > 1 & ncol(freq) > 1) {
    errors = append(errors, "freq provided as matrix with dimensions [" + nrow(freq) + ', ' + ncol(freq) + "], but must be a vector.")
  }
  else if (ncol(freq) > 1) {
    freq = t(freq) # Transposes row to column vector for convenience.
  }
  if (length(freq) != nrow(patterns)) {
    errors = append(errors, "Length of freq must be equal to the number of rows in the patterns matrix. freq has length "
      + length(freq) + " while patterns contains " + nrow(patterns) + " rows.")
  }
  if (length(freq) != nrow(weights)) {
    errors = append(errors, "Length of freq must be equal to the number of rows in the weights matrix. freq has length "
      + length(freq) + " while weights contains " + nrow(weights) + " rows.")
  }
  if (abs(sum(freq) - 1) > 1e-7) {
    errors = append(errors, "Values in freq vector must approximately sum to 1. Sum was " + sum(freq) + ".")
  }

  # About patterns:
  if (ncol(X) != ncol(patterns)) {
    errors = append(errors, "Input dataset must contain the same number of columns as the patterns matrix. Dataset contains "
      + ncol(X) + " columns while patterns contains " + ncol(patterns) + ".")
  }
  if (ncol(patterns) != ncol(weights)) {
    errors = append(errors, "The patterns matrix must contain the same number of columns as the weights matrix. The patterns matrix contains "
      + ncol(patterns) + " columns while weights contains " + ncol(weights) + ".")
  }
  if (max(patterns != 0 & patterns != 1) > 0) {
    # Use the same predicate as the trigger above so that every offending row is listed
    # (e.g. a value of 0.5 is neither > 1 nor < 0 but still violates the 0/1 rule).
    errorPatterns = rowMaxs(patterns != 0 & patterns != 1)
    errorPatterns = removeEmpty(target=seq(1, nrow(patterns)), margin="rows", select=errorPatterns)
    errorString = u_getErrorIndices(errorPatterns)
    errors = append(errors, "The patterns matrix must contain only values of 0 or 1. The following rows in patterns break this rule: " + errorString + ".")
  }
  if (sum(rowMins(patterns)) > 0) {
    errorPatterns = removeEmpty(target=seq(1, nrow(patterns)), margin="rows", select=rowMins(patterns) == 1)
    errorString = u_getErrorIndices(errorPatterns)
    errors = append(errors, "Each row in the patterns matrix must contain at least one value of 0. The following rows in patterns break this rule: " + errorString + ".")
  }

  # About weights:
  # A row is "all zeros" exactly when none of its entries differs from 0. Testing on the
  # indicator (weights != 0) keeps the check correct for fractional or negative weights,
  # and the reported rows are the all-zero ones (row indicator == 0), not the valid ones.
  if (mech != "MCAR" & sum(rowMaxs(weights != 0)) < nrow(weights)) {
    errorWeights = removeEmpty(target=seq(1, nrow(weights)), margin="rows", select=rowMaxs(weights != 0) == 0)
    errorString = u_getErrorIndices(errorWeights)
    errors = append(errors, "Indicated weights of all 0's for some patterns when mechanism isn't MCAR. The following rows in weights break this rule: " + errorString + ".")
  }
  if (ncol(X) != ncol(weights)) {
    errors = append(errors, "Input dataset must contain the same number of columns as the weights matrix. Dataset contains "
      + ncol(X) + " columns while weights contains " + ncol(weights) + ".")
  }

  # Collect errors, if any, and abort with one combined message:
  if (length(errors) > 0) {
    errorStrings = ""
    for (i in 1:length(errors)) {
      errorStrings = errorStrings + "\nampute: " + as.scalar(errors[i])
    }
    stop(errorStrings)
  }
}
183+
184+
# Replaces empty freq / patterns / weights arguments by their defaults.
#   patterns: square matrix in which pattern i amputes feature i only.
#   weights:  all zeros for MCAR (always, even if weights were supplied);
#             otherwise, when empty: patterns for MAR, 1 - patterns for any other mech.
#   freq:     uniform over the patterns.
u_handleDefaults = function(Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights, String mech, Integer numFeatures)
  return (Matrix[Double] freq, Matrix[Double] patterns, Matrix[Double] weights) {
  # Record up front which arguments were left empty. The results are stored in variables
  # rather than called inside the conditions, as in the original ("FIX ME" notes).
  noPatterns = u_isEmpty(patterns)
  noWeights = u_isEmpty(weights)
  noFreq = u_isEmpty(freq)

  # Patterns: ones everywhere except a zero diagonal.
  if (noPatterns) {
    patterns = 1 - diag(matrix(1, rows=numFeatures, cols=1))
  }
  numPatterns = nrow(patterns)

  # Weights: depend on the missingness mechanism.
  if (mech == "MCAR") {
    # Weights are irrelevant for MCAR; any provided weights are overridden.
    weights = matrix(0, rows=numPatterns, cols=numFeatures)
  }
  else if (noWeights & mech == "MAR") {
    # MAR: features that will be missing get weight 0.
    weights = patterns
  }
  else if (noWeights) {
    # MNAR case: features that stay observed get weight 0.
    weights = 1 - patterns
  }

  # Frequencies: every pattern equally likely.
  if (noFreq) {
    freq = matrix(1 / numPatterns, rows=numPatterns, cols=1)
  }
}
213+
214+
# Renders the entries of errorPatterns as integers joined by ", " (e.g. "2, 5, 7").
u_getErrorIndices = function(Matrix[Double] errorPatterns) return (String errorString) {
  errorString = ""
  for (i in 1:length(errorPatterns)) {
    # Every entry but the first is preceded by the separator.
    if (i > 1) {
      errorString = errorString + ", "
    }
    errorString = errorString + as.integer(as.scalar(errorPatterns[i]))
  }
}
223+
224+
# TRUE when X holds no cells at all (zero rows or zero columns).
u_isEmpty = function(Matrix[Double] X) return (Boolean emptiness) {
  numCells = nrow(X) * ncol(X)
  emptiness = (numCells == 0)
}
227+
228+
# Assigns numSamples to a number of categories based on the frequencies provided in freq.
229+
# Randomly assigns numSamples samples to length(freq) groups.
#
# numSamples  number of samples to distribute
# freq        relative frequency of each group
# seed        seed forwarded to rand()
#
# Returns
#   groupAssignments  numSamples-by-1 vector holding the group id (1..numGroups) of each sample
#   groupCounts       numGroups-by-1 vector holding the number of samples per group
u_randomChoice = function(Integer numSamples, Matrix[Double] freq, Double seed = -1)
  return (Matrix[Double] groupAssignments, Matrix[Double] groupCounts) {
  numGroups = length(freq)
  if (numGroups == 1) { # Assigns all samples to the same group.
    groupCounts = matrix(numSamples, rows=1, cols=1)
    groupAssignments = matrix(1, rows=numSamples, cols=1)
  }
  else { # Assigns based on cumulative probability thresholds:
    cumSum = rbind(matrix(0, rows=1, cols=1), cumsum(freq)) # For, e.g., freq == [0.1, 0.4, 0.5], we get cumSum = [0.0, 0.1, 0.5, 1.0].
    # freq only has to sum to 1 approximately, so the last threshold can end up slightly
    # below 1. Draws above it would match no bin and keep group id 0, which later leaves
    # all-zero rows in the output. Clamp the final threshold so the bins cover every draw.
    cumSum[numGroups + 1, 1] = max(1, as.scalar(cumSum[numGroups + 1, 1]))
    random = rand(rows=numSamples, cols=1, min=0, max=1, pdf="uniform", seed=seed)
    groupCounts = matrix(0, rows=numGroups, cols=1)
    groupAssignments = matrix(0, rows=numSamples, cols=1)

    for (i in 1:numGroups) {
      # Sample belongs to group i when its draw falls into [cumSum[i], cumSum[i + 1]).
      assigned = (random >= cumSum[i]) & (random < cumSum[i + 1])
      while (FALSE) {} # Kept from the original ("FIX ME"); presumably forces a statement-block cut -- confirm before removing.
      groupCounts[i] = sum(assigned)
      groupAssignments = groupAssignments + i * assigned
    }
  }
}
250+
251+
# Extracts the rows of X that were assigned to pattern patternNum.
#
# Returns
#   groupSamples  the selected rows of X
#   backMapping   for each selected row, its row position in X
# groupSize and numFeatures are accepted but not used by this function.
u_getGroupSamples = function(Matrix[Double] X, Matrix[Double] groupAssignments, Integer numSamples, Integer groupSize, Integer numFeatures, Integer patternNum)
  return (Matrix[Double] groupSamples, Matrix[Double] backMapping) {
  # 1 for rows that belong to the requested pattern, 0 otherwise.
  isMember = (groupAssignments == patternNum)
  # Apply the same selection to the row positions and to the data.
  backMapping = removeEmpty(target=seq(1, numSamples), margin="rows", select=isMember)
  groupSamples = removeEmpty(target=X, margin="rows", select=isMember)
}
257+
258+
# Assigns amputation probabilities to each sample:
259+
# Assigns an amputation probability to each sample of one pattern group.
#
# sumScores  weighted sum score per sample
# groupSize  number of samples in the group
# prop       requested proportion of amputed rows
#
# Returns probs, one probability per sample.
u_getProbs = function(Matrix[Double] sumScores, Integer groupSize, Double prop)
  return(Matrix[Double] probs) {
  # A non-empty score vector always has at least one distinct value, so "all scores are
  # identical" means exactly one unique value. This covers all-zero weights (e.g., MCAR)
  # and single-row groups; in that case every sample simply gets probability prop.
  if (length(unique(sumScores)) == 1) {
    probs = matrix(prop, rows=groupSize, cols=1)
  }
  else {
    zScores = scale(X=sumScores)
    rounded = round(prop * 100) / 100 # Rounds to two decimals for numeric stability.
    probs = u_binaryShiftSearch(zScores=zScores, prop=rounded)
  }
}
270+
271+
# Performs a binary search for the optimum shift transformation to the weighted sum scores in order to obtain the desired missingness proportion.
272+
# Bisection over a shift added to the z-scores before the sigmoid, so that the mean of
# the resulting probabilities approaches prop.
# The shift is searched within [-3, 3]; the loop ends once the mean is within 0.001 of
# prop or after 100 iterations. Returns the probabilities from the last iteration.
u_binaryShiftSearch = function(Matrix[Double] zScores, Double prop)
  return (Matrix[Double] probsArray) {
  lo = -3
  hi = 3
  epsilon = 0.001
  maxIter = 100

  shift = 0
  probsArray = zScores
  iter = 0
  closeEnough = FALSE

  while (iter < maxIter & !closeEnough) {
    iter += 1
    shift = lo + (hi - lo) / 2
    probsArray = u_sigmoid(zScores + shift) # Right-sigmoid probability (R implementation's default).
    currentProb = mean(probsArray)
    # Mean too high -> search smaller shifts; otherwise search larger shifts.
    if (currentProb > prop) {
      hi = shift
    }
    else {
      lo = shift
    }
    # A NaN mean never satisfies this comparison, so the search keeps going until maxIter.
    closeEnough = abs(currentProb - prop) < epsilon
  }
}
296+
297+
# Element-wise logistic function: 1 / (1 + exp(-X)).
u_sigmoid = function(Matrix[Double] X)
  return (Matrix[Double] sigmoided) {
  expNeg = exp(-X)
  sigmoided = 1 / (1 + expNeg)
}
301+
302+
# Computes the inclusive row range [start, end] that pattern patternNum occupies when
# the groups are stacked one after another in pattern order.
u_getBounds = function(Matrix[Double] numPerGroup, Integer groupSize, Integer patternNum)
  return(Integer start, Integer end) {
  # Number of rows taken by all earlier patterns (none for the first pattern).
  rowsBefore = 0.0
  if (patternNum != 1) {
    rowsBefore = sum(numPerGroup[1:(patternNum - 1), ])
  }
  start = rowsBefore + 1
  end = start + groupSize - 1
}

src/main/java/org/apache/sysds/common/Builtins.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ public enum Builtins {
4747
ALS_DS("alsDS", true),
4848
ALS_PREDICT("alsPredict", true),
4949
ALS_TOPK_PREDICT("alsTopkPredict", true),
50+
AMPUTE("ampute", true),
5051
APPLY_PIPELINE("apply_pipeline", true),
5152
APPLY_SCHEMA("applySchema", false),
5253
ARIMA("arima", true),

0 commit comments

Comments
 (0)