
Batching GNNGraphs not compatible with Zygote gradient operation #594

@finn777-yea


Hello, and many thanks for your great work!

I am building a model whose forward function needs to batch the input GNNGraphs. During backpropagation, the following error is raised:

```
ERROR: Mutating arrays is not supported -- called copyto!(Vector{Symbol}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations
```
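The "avoid mutating operations" fix suggested by the error applies to code you control. Zygote provides `Zygote.Buffer` for this: a scratch array that can be filled with `setindex!` and then frozen into a regular array with `copy`, which Zygote can differentiate through. A minimal sketch of the pattern (the function `squares` is hypothetical and unrelated to GNNGraphs; it cannot be applied to the internals of `batch` without changing GNNGraphs itself):

```julia
using Zygote

# Hypothetical example: compute sum(x .^ 2) by filling a buffer element by element.
function squares(x)
    buf = Zygote.Buffer(x)      # differentiable scratch space with the same size/eltype as x
    for i in eachindex(x)
        buf[i] = x[i]^2         # writes to a Buffer are allowed, unlike writes to a plain Array
    end
    return sum(copy(buf))       # copy turns the buffer into an ordinary array
end

g = gradient(squares, [1.0, 2.0, 3.0])[1]   # d/dx sum(x.^2) = 2x
```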

Here is a minimal example that reproduces the error:

```julia
using Zygote
using GNNGraphs
using Flux
using CUDA

g1 = GNNGraph([1,2,3], [2,3,4])
g2 = GNNGraph([1,2,3], [2,4,5])

function test_fn(x)
    # `x` is not used: calling `batch` inside the differentiated function is enough
    graphs = [g1, g2]
    gs = batch(graphs)

    return sum(gs.num_nodes)
end

# This will error
gradient(test_fn, 1.0)
```
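When the graphs are already known before the gradient call (as `g1` and `g2` are here), one workaround is to call `batch` outside the differentiated function so Zygote never traces it. A minimal sketch, assuming 3-dimensional node features passed via `ndata` and a hypothetical weight matrix `W` standing in for the model's parameters:

```julia
using GNNGraphs, Flux, Zygote

# Hypothetical node features: 3 features per node (g1 has 4 nodes, g2 has 5).
g1 = GNNGraph([1, 2, 3], [2, 3, 4]; ndata = rand(Float32, 3, 4))
g2 = GNNGraph([1, 2, 3], [2, 4, 5]; ndata = rand(Float32, 3, 5))

# Batch once, outside of `gradient`, so the mutation inside `batch` is never differentiated.
gb = batch([g1, g2])

# Hypothetical parameters standing in for a model's weights.
W = rand(Float32, 2, 3)

# The loss only reads from the already-batched graph.
loss(W, g) = sum(W * g.ndata.x)

grad = gradient(W -> loss(W, gb), W)[1]
```

This keeps gradients with respect to the parameters intact; it only works when the graph structure does not itself depend on quantities being differentiated.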

Is there a way to work around this? Or is it valid to simply mark `batch` as non-differentiable with `Zygote.@nograd batch`? Any insights are welcome :)
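If `batch` must stay inside the forward function, a variant of the `@nograd` idea is to wrap only that call in `ChainRulesCore.ignore_derivatives`, which Zygote respects; as far as I know, recent Zygote releases also deprecate `Zygote.@nograd` in favor of `ChainRulesCore.@non_differentiable`. A minimal sketch, assuming `ChainRulesCore` is available in the environment; `test_fn_ignored` is a hypothetical variant of `test_fn` that actually uses `x`:

```julia
using GNNGraphs, Flux, Zygote
using ChainRulesCore: ignore_derivatives

g1 = GNNGraph([1, 2, 3], [2, 3, 4])
g2 = GNNGraph([1, 2, 3], [2, 4, 5])

function test_fn_ignored(x)
    # Zygote treats the closure's result as a constant and does not trace `batch`.
    gs = ignore_derivatives(() -> batch([g1, g2]))
    return x * gs.num_nodes     # 4 + 5 = 9 nodes in the batched graph
end

grad = gradient(test_fn_ignored, 1.0)[1]
```

Caveat: no gradient flows back through the ignored call, so if the input graphs carry node or edge features produced by learnable upstream layers, those gradients are silently dropped. In that case, batching the graph structure without features and passing the features through the differentiated path separately is safer.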
