|
| 1 | +#------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 4 | +# or more contributor license agreements. See the NOTICE file |
| 5 | +# distributed with this work for additional information |
| 6 | +# regarding copyright ownership. The ASF licenses this file |
| 7 | +# to you under the Apache License, Version 2.0 (the |
| 8 | +# "License"); you may not use this file except in compliance |
| 9 | +# with the License. You may obtain a copy of the License at |
| 10 | +# |
| 11 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +# |
| 13 | +# Unless required by applicable law or agreed to in writing, |
| 14 | +# software distributed under the License is distributed on an |
| 15 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 16 | +# KIND, either express or implied. See the License for the |
| 17 | +# specific language governing permissions and limitations |
| 18 | +# under the License. |
| 19 | +# |
| 20 | +#------------------------------------------------------------- |
| 21 | + |
| 22 | +# clustered gradients (no better than random sampling) |
| 23 | +# Accuracy (%): 88.39168202158463 |
| 24 | +# Accuracy (%): 83.76164636019249 |
| 25 | +# Accuracy (%): 96.04281828551373 |
| 26 | +# Accuracy (%): 78.03829220845705 |
| 27 | +# Accuracy (%): 100.0 |
| 28 | +# Accuracy (%): 66.66325381386301 |
| 29 | + |
| 30 | +# clustered data (no better than random sampling) |
| 31 | +# Accuracy (%): 86.95270685268054 |
| 32 | +# Accuracy (%): 83.3316269069315 |
| 33 | +# Accuracy (%): 96.472755988418 |
| 34 | +# Accuracy (%): 73.82000614313505 |
| 35 | +# Accuracy (%): 100.0 |
| 36 | +# Accuracy (%): 69.74505989556671 |
| 37 | + |
| 38 | +# Experiment: train multiLogReg on a cluster-selected subsample of the Adult |
| 39 | +# training split and evaluate it, once per sampling fraction in sf. |
| 40 | +X = read("data/Adult_X.csv") |
| 41 | +y = read("data/Adult_y.csv") |
| 42 | +B = read("data/Adult_W.csv") |
| 43 | + |
| 44 | +[Xtrain,Xtest,ytrain,ytest] = split(X=X,Y=y,f=0.7,cont=FALSE,seed=7) |
| 45 | + |
| 46 | +# Per-row gradient signal under the reference model B, computed once. B is never |
| 47 | +# overwritten below, so every sampling fraction clusters the same gradients. |
| 48 | +# Assumes B holds ncol(X) weight rows followed by one intercept row - TODO confirm. |
| 49 | +w = B[1:ncol(X), ]; |
| 50 | +icpt = as.scalar(B[nrow(B),]) |
| 51 | +Xgrad = sigmoid(Xtrain %*% w + icpt)-(ytrain == 1) |
| 52 | + |
| 53 | +sf = matrix("0.1 0.01 0.001", rows=3, cols=1) |
| 54 | +for(i in 1:nrow(sf)) { |
| 55 | +  sfi = as.scalar(sf[i]); |
| 56 | +  [C,Y]=kmeans(X=Xgrad, k=sfi*nrow(Xtrain), seed=7) |
| 57 | + |
| 58 | +  # Yone: one row per training row, with a single 1 in the column of its cluster id. |
| 59 | +  Yone = table(seq(1,nrow(Xtrain)),Y) |
| 60 | +  # Transpose so that each row is a cluster; rowIndexMax then returns one member |
| 61 | +  # row index per cluster. Without the transpose it would only reproduce Y. |
| 62 | +  I = rowIndexMax(t(Yone)); |
| 63 | +  # P: selection matrix with one row per cluster and a 1 at the chosen training row. |
| 64 | +  P = table(seq(1,nrow(I)), I, nrow(I), nrow(Xtrain)) |
| 65 | +  P = removeEmpty(target=P, margin="rows"); |
| 66 | +  Xtrain2 = P %*% Xtrain |
| 67 | +  ytrain2 = P %*% ytrain |
| 68 | +  # Keep the model trained on the subsample separate from the reference model B. |
| 69 | +  Bi = multiLogReg(X=Xtrain2, Y=ytrain2, maxii=50, icpt=2, reg=0.001, verbose=FALSE); |
| 70 | + |
| 71 | +  # Evaluate on the subsample it was trained on, then on the held-out test split. |
| 72 | +  [M,yhat,acc] = multiLogRegPredict(X=Xtrain2, B=Bi, Y=ytrain2, verbose=TRUE); |
| 73 | +  [M,yhat,acc] = multiLogRegPredict(X=Xtest, B=Bi, Y=ytest, verbose=TRUE); |
| 74 | +} |
| 75 | + |
0 commit comments