-
Notifications
You must be signed in to change notification settings - Fork 12
Entropy-regularised Gromov-Wasserstein #165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 21 commits
2ef3e2b
11efd8c
3273976
0956c3b
c22d7e7
ff1a92c
267dfad
21609b0
9699e04
8510397
2f2428f
20d5885
df41c28
a7c1a38
19e4cab
56c4f9b
6e3ac4c
5c376ae
af2a493
6bc3127
a806f0f
f704397
71351b9
f2acc56
0635305
c3efe5a
39f0b36
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,7 +20,7 @@ jobs: | |
| strategy: | ||
| matrix: | ||
| version: | ||
| - '1.6' | ||
| - '1.8' | ||
| - '1' | ||
| - 'nightly' | ||
| os: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| name = "OptimalTransport" | ||
| uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" | ||
| authors = ["zsteve <[email protected]>"] | ||
| version = "0.3.20" | ||
| version = "0.3.21" | ||
|
|
||
| [deps] | ||
| ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,129 @@ | ||
| # This file is machine-generated - editing it directly is not advised | ||
zsteve marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| julia_version = "1.7.0" | ||
| manifest_format = "2.0" | ||
|
|
||
| [[deps.ArgTools]] | ||
| uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" | ||
|
|
||
| [[deps.Artifacts]] | ||
| uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | ||
|
|
||
| [[deps.Base64]] | ||
| uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" | ||
|
|
||
| [[deps.CompilerSupportLibraries_jll]] | ||
| deps = ["Artifacts", "Libdl"] | ||
| uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" | ||
|
|
||
| [[deps.Conda]] | ||
| deps = ["Downloads", "JSON", "VersionParsing"] | ||
| git-tree-sha1 = "6e47d11ea2776bc5627421d59cdcc1296c058071" | ||
| uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d" | ||
| version = "1.7.0" | ||
|
|
||
| [[deps.Dates]] | ||
| deps = ["Printf"] | ||
| uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" | ||
|
|
||
| [[deps.Downloads]] | ||
| deps = ["ArgTools", "LibCURL", "NetworkOptions"] | ||
| uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" | ||
|
|
||
| [[deps.JSON]] | ||
| deps = ["Dates", "Mmap", "Parsers", "Unicode"] | ||
| git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" | ||
| uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" | ||
| version = "0.21.3" | ||
|
|
||
| [[deps.LibCURL]] | ||
| deps = ["LibCURL_jll", "MozillaCACerts_jll"] | ||
| uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" | ||
|
|
||
| [[deps.LibCURL_jll]] | ||
| deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] | ||
| uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" | ||
|
|
||
| [[deps.LibSSH2_jll]] | ||
| deps = ["Artifacts", "Libdl", "MbedTLS_jll"] | ||
| uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" | ||
|
|
||
| [[deps.Libdl]] | ||
| uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | ||
|
|
||
| [[deps.LinearAlgebra]] | ||
| deps = ["Libdl", "libblastrampoline_jll"] | ||
| uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
|
|
||
| [[deps.MacroTools]] | ||
| deps = ["Markdown", "Random"] | ||
| git-tree-sha1 = "3d3e902b31198a27340d0bf00d6ac452866021cf" | ||
| uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" | ||
| version = "0.5.9" | ||
|
|
||
| [[deps.Markdown]] | ||
| deps = ["Base64"] | ||
| uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" | ||
|
|
||
| [[deps.MbedTLS_jll]] | ||
| deps = ["Artifacts", "Libdl"] | ||
| uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" | ||
|
|
||
| [[deps.Mmap]] | ||
| uuid = "a63ad114-7e13-5084-954f-fe012c677804" | ||
|
|
||
| [[deps.MozillaCACerts_jll]] | ||
| uuid = "14a3606d-f60d-562e-9121-12d972cd8159" | ||
|
|
||
| [[deps.NetworkOptions]] | ||
| uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" | ||
|
|
||
| [[deps.OpenBLAS_jll]] | ||
| deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] | ||
| uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" | ||
|
|
||
| [[deps.Parsers]] | ||
| deps = ["Dates"] | ||
| git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7" | ||
| uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" | ||
| version = "2.2.3" | ||
|
|
||
| [[deps.Printf]] | ||
| deps = ["Unicode"] | ||
| uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" | ||
|
|
||
| [[deps.PyCall]] | ||
| deps = ["Conda", "Dates", "Libdl", "LinearAlgebra", "MacroTools", "Serialization", "VersionParsing"] | ||
| git-tree-sha1 = "1fc929f47d7c151c839c5fc1375929766fb8edcc" | ||
| uuid = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" | ||
| version = "1.93.1" | ||
|
|
||
| [[deps.Random]] | ||
| deps = ["SHA", "Serialization"] | ||
| uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
|
|
||
| [[deps.SHA]] | ||
| uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" | ||
|
|
||
| [[deps.Serialization]] | ||
| uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" | ||
|
|
||
| [[deps.Unicode]] | ||
| uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" | ||
|
|
||
| [[deps.VersionParsing]] | ||
| git-tree-sha1 = "58d6e80b4ee071f5efd07fda82cb9fbe17200868" | ||
| uuid = "81def892-9a0e-5fdd-b105-ffc91e053289" | ||
| version = "1.3.0" | ||
|
|
||
| [[deps.Zlib_jll]] | ||
| deps = ["Libdl"] | ||
| uuid = "83775a58-1f1d-513f-b197-d71354ab007a" | ||
|
|
||
| [[deps.libblastrampoline_jll]] | ||
| deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] | ||
| uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" | ||
|
|
||
| [[deps.nghttp2_jll]] | ||
| deps = ["Artifacts", "Libdl"] | ||
| uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| [deps] | ||
zsteve marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" | ||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,91 @@ | ||||||||
| # Gromov-Wasserstein solver | ||||||||
|
|
||||||||
| # Supertype used to annotate the `alg` argument of `entropic_gromov_wasserstein`. | ||||||||
| abstract type EntropicGromovWasserstein end | ||||||||
|
|
||||||||
| # Gromov-Wasserstein algorithm whose inner step is a Sinkhorn solver: `alg_step` is | ||||||||
| # forwarded to `build_solver` on every outer iteration of `entropic_gromov_wasserstein`. | ||||||||
| # NOTE(review): if `Sinkhorn` is an abstract type, `alg_step` is an abstractly typed field; | ||||||||
| # a type parameter (`{S<:Sinkhorn}`) may be preferable — confirm against its declaration. | ||||||||
| struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein | ||||||||
| alg_step::Sinkhorn | ||||||||
| end | ||||||||
|
|
||||||||
| """ | ||||||||
| entropic_gromov_wasserstein( | ||||||||
| μ, ν, Cμ, Cν, ε, alg=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); | ||||||||
| atol = nothing, rtol = nothing, check_convergence = 10, maxiter = 1_000, kwargs... | ||||||||
| ) | ||||||||
| Computes the transport map for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target | ||||||||
| marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. That is, we seek `γ` a local minimizer of | ||||||||
| ```math | ||||||||
| \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), | ||||||||
| ``` | ||||||||
| where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref). | ||||||||
| This function employs the iterative method described in (Section 10.6.4, [^PC19]), which solves a series of Sinkhorn iteration sub-problems to arrive at a solution. Note that the Gromov-Wasserstein problem is non-convex owing to the cross-terms in the | ||||||||
| objective function, and thus in general one is only guaranteed to arrive at a local optimum. | ||||||||
| Every `check_convergence` steps, the current iteration of `γ` is compared with `γ_prev` (the iterate from `check_convergence` steps ago). | ||||||||
| The quantity ``\\| \\gamma - \\gamma_\\text{prev} \\|_1`` is compared against `atol` and `rtol`. | ||||||||
| [^PC19]: Peyré, G. and Cuturi, M., 2019. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6), pp.355-607. | ||||||||
| See also: [`sinkhorn`](@ref) | ||||||||
| """ | ||||||||
| function entropic_gromov_wasserstein( | ||||||||
| μ::AbstractVector, | ||||||||
| ν::AbstractVector, | ||||||||
| Cμ::AbstractMatrix, | ||||||||
| Cν::AbstractMatrix, | ||||||||
| ε::Real, | ||||||||
| alg::EntropicGromovWasserstein=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); | ||||||||
| atol=nothing, | ||||||||
| rtol=nothing, | ||||||||
| check_convergence=10, | ||||||||
| maxiter::Int=1_000, | ||||||||
| kwargs..., | ||||||||
| ) | ||||||||
| T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) | ||||||||
| C = similar(Cμ, T, size(μ, 1), size(ν, 1)) | ||||||||
| tmp = similar(C) | ||||||||
| plan = similar(C) | ||||||||
| @. plan = μ * ν' | ||||||||
| plan_prev = similar(C) | ||||||||
| plan_prev .= plan | ||||||||
| norm_plan = sum(plan) | ||||||||
|
|
||||||||
| _atol = atol === nothing ? 0 : atol | ||||||||
| _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol | ||||||||
|
|
||||||||
| function get_new_cost!(C, plan, tmp, Cμ, Cν) | ||||||||
| A_batched_mul_B!(tmp, Cμ, plan) | ||||||||
| return A_batched_mul_B!(C, tmp, -4Cν) | ||||||||
|
||||||||
| # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? | ||||||||
| # added the factor of 4 here to ensure reproducibility for the same value of ε. | ||||||||
| # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 | ||||||||
| end | ||||||||
|
|
||||||||
| get_new_cost!(C, plan, tmp, Cμ, Cν) | ||||||||
| to_check_step = check_convergence | ||||||||
|
|
||||||||
| isconverged = false | ||||||||
| for iter in 1:maxiter | ||||||||
| # perform Sinkhorn algorithm | ||||||||
| solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) | ||||||||
| solve!(solver) | ||||||||
| # compute optimal transport plan | ||||||||
| plan = sinkhorn_plan(solver) | ||||||||
|
|
||||||||
| to_check_step -= 1 | ||||||||
| if to_check_step == 0 || iter == maxiter | ||||||||
| # reset counter | ||||||||
| to_check_step = check_convergence | ||||||||
| isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) | ||||||||
|
||||||||
| isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan) | |
| plan_prev .-= plan # used as a temporary array here to reduce allocations | |
| isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`norm_plan` is never updated, it seems, but is always set to `sum(plan)` of the initial randomly initialized `plan`?
The initial plan is taken to be the independent coupling, and here we only consider the balanced problem, so `norm_plan` should not change. I agree, however, that this is a special case of the unbalanced problem, where it would not be constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also avoid allocations here by writing:
Good catch, done
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| [deps] | ||
zsteve marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33" | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| using OptimalTransport | ||
|
|
||
| using Distances | ||
| using PythonOT: PythonOT | ||
|
|
||
| using Random | ||
| using Test | ||
| using LinearAlgebra | ||
|
|
||
| const POT = PythonOT | ||
|
|
||
| Random.seed!(100) | ||
|
|
||
| # Checks `entropic_gromov_wasserstein` against the PythonOT result on the same inputs. | ||
| @testset "gromov.jl" begin | ||
| @testset "entropic_gromov_wasserstein" begin | ||
| # sizes of the source and target measures (deliberately different) | ||
| M, N = 250, 200 | ||
|
|
||
| # uniform weights; supports are random scalars drawn with `rand` | ||
| μ = fill(1 / M, M) | ||
| μ_spt = rand(M) | ||
| ν = fill(1 / N, N) | ||
| ν_spt = rand(N) | ||
|
|
||
| # intra-space cost matrices: pairwise squared Euclidean distances within each support | ||
| Cμ = pairwise(SqEuclidean(), μ_spt) | ||
| Cν = pairwise(SqEuclidean(), ν_spt) | ||
|
|
||
| # same marginals, costs and regularisation value (0.01) passed to both implementations | ||
| γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence=10) | ||
| γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) | ||
|
|
||
| # the two plans must agree to a relative tolerance of 1e-6 | ||
| @test γ ≈ γ_pot rtol = 1e-6 | ||
| end | ||
| end |
Uh oh!
There was an error while loading. Please reload this page.