Skip to content

Commit fe27ab5

Browse files
committed
fix ICA example
oops
1 parent 2ca3553 commit fe27ab5

File tree

1 file changed

+31
-15
lines changed

1 file changed

+31
-15
lines changed

src/MLJMultivariateStatsInterface.jl

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,7 @@ Train the machine using `fit!(mach, rows=...)`.
556556
The fields of `fitted_params(mach)` are:
557557
558558
- `projection`: The estimated component matrix.
559+
559560
- `mean`: The estimated mean vector.
560561
561562
# Report
@@ -564,34 +565,42 @@ The fields of `report(mach)` are:
564565
565566
- `indim`: Dimension (number of columns) of the training data and new data to be transformed.
566567
- `outdim`: Dimension of transformed data.
568+
567569
- `mean`: The mean of the untransformed training data, of length `indim`.
568570
569571
# Examples
570572
571573
```
572574
using MLJ
573-
using LinearAlgebra
574575
575576
ICA = @load ICA pkg=MultivariateStats
576577
577-
time = 8 .\\ 0:2001
578+
times = range(0, 8, length=2000)
578579
579-
sine_wave = sin.(2*time)
580-
square_wave = sign.(sin.(3*time))
581-
sawtooth_wave = repeat(collect(0:10) / 4, 182)
582-
signal = [sine_wave, square_wave, sawtooth_wave]
583-
add_noise(x) = x + randn()
584-
signal = map((x -> add_noise.(x)), signal)
585-
signal = permutedims(hcat(signal...))'
580+
sine_wave = sin.(2*times)
581+
square_wave = sign.(sin.(3*times))
582+
sawtooth_wave = map(t -> mod(2t, 2) - 1, times)
583+
signals = hcat(sine_wave, square_wave, sawtooth_wave)
584+
noisy_signals = signals + 0.2*randn(size(signals))
586585
587586
mixing_matrix = [ 1 1 1; 0.5 2 1; 1.5 1 2]
588-
X = MLJ.table(signal * mixing_matrix)
587+
X = MLJ.table(noisy_signals*mixing_matrix)
589588
590-
model = ICA(k = 3, tol=0.1)
591-
mach = machine(model, X) |> fit! # this errors ERROR: MethodError: no method matching size(::MultivariateStats.ICA{Float64}, ::Int64)
589+
model = ICA(outdim = 3, tol=0.1)
590+
mach = machine(model, X) |> fit!
591+
592+
X_unmixed = transform(mach, X)
593+
594+
using Plots
595+
596+
plot(X.x2)
597+
plot(X.x2)
598+
plot(X.x3)
599+
600+
plot(X_unmixed.x1)
601+
plot(X_unmixed.x2)
602+
plot(X_unmixed.x3)
592603
593-
Xproj = transform(mach, X)
594-
@info sum(abs, Xproj - signal)
595604
```
596605
597606
See also
@@ -611,7 +620,14 @@ possible the degree to which the target classes are separable can be discrimated
611620
either for dimension reduction of the features (see transform below) or for probabilistic
612621
classification of the target (see predict below).
613622
614-
In the case of prediction, the class probability for a new observation reflects the proximity of that observation to training observations associated with that class, and how far away the observation is from those associated with other classes. Specifically, the distances, in the transformed (projected) space, of a new observation, from the centroid of each target class, is computed; the resulting vector of distances (times minus one) is passed to a softmax function to obtain a class probability prediction. Here "distance" is computed using a user-specified distance function.
623+
In the case of prediction, the class probability for a new observation reflects the
624+
proximity of that observation to training observations associated with that class, and how
625+
far away the observation is from those associated with other classes. Specifically, the
626+
distances, in the transformed (projected) space, of a new observation, from the centroid of
627+
each target class, is computed; the resulting vector of distances (times minus one) is
628+
passed to a softmax function to obtain a class probability prediction. Here "distance" is
629+
computed using a user-specified distance function.
630+
615631
# Training data
616632
617633
In MLJ or MLJBase, bind an instance `model` to data with

0 commit comments

Comments
 (0)