
Introduce ConvDims/PoolDims parameter structs to reduce argument counts #16

@kolkov

Background

As @marcelloh pointed out in #14, our conv2d and maxpool2d functions take 10-14 arguments each, well above the recommended maximum of 5.

Problem

Functions like conv2dInputBackwardFloat32 have 14 parameters:

func conv2dInputBackwardFloat32(
    inputGrad, grad, kernel *tensor.RawTensor,
    n, cIn, h, w, cOut, kH, kW, hOut, wOut, stride, padding int,
)

Proposed Solution

Introduce parameter structs following the Burn framework pattern:

// ConvDims groups the dimension parameters of a 2D convolution.
type ConvDims struct {
    N, CIn, H, W       int  // Input: batch size, channels, height, width
    COut, KH, KW       int  // Kernel: output channels, kernel height, kernel width
    HOut, WOut         int  // Output: height, width
    Stride, Padding    int  // Convolution parameters
}

// PoolDims groups the dimension parameters of a 2D pooling operation.
type PoolDims struct {
    N, C, H, W         int  // Input: batch size, channels, height, width
    KH, KW             int  // Kernel: height, width
    HOut, WOut         int  // Output: height, width
    Stride, Padding    int  // Pooling parameters
}
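Since HOut and WOut are derived from the other fields, the struct could also carry the derivation and a basic sanity check. The sketch below is only an illustration: the `Validate` and `FillOutputSize` methods are hypothetical names that are not part of this proposal, and it assumes symmetric padding, no dilation, and floor division for the output size.

```go
package main

import "fmt"

// ConvDims mirrors the struct proposed in this issue.
type ConvDims struct {
	N, CIn, H, W    int // Input dimensions
	COut, KH, KW    int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Convolution parameters
}

// Validate is a hypothetical helper that rejects dimension
// combinations for which the output size formula is meaningless.
func (d *ConvDims) Validate() error {
	if d.Stride <= 0 {
		return fmt.Errorf("stride must be positive, got %d", d.Stride)
	}
	if d.Padding < 0 {
		return fmt.Errorf("padding must be non-negative, got %d", d.Padding)
	}
	if d.KH <= 0 || d.KW <= 0 {
		return fmt.Errorf("kernel size must be positive, got %dx%d", d.KH, d.KW)
	}
	if d.H+2*d.Padding < d.KH || d.W+2*d.Padding < d.KW {
		return fmt.Errorf("kernel %dx%d does not fit padded input %dx%d",
			d.KH, d.KW, d.H+2*d.Padding, d.W+2*d.Padding)
	}
	return nil
}

// FillOutputSize is a hypothetical helper that derives HOut and WOut
// from the input size, kernel size, stride and padding.
// Assumes symmetric padding and no dilation.
func (d *ConvDims) FillOutputSize() {
	d.HOut = (d.H+2*d.Padding-d.KH)/d.Stride + 1
	d.WOut = (d.W+2*d.Padding-d.KW)/d.Stride + 1
}
```

A caller would fill the input and kernel fields, call `Validate`, then `FillOutputSize`, and pass the same `*ConvDims` to every kernel, so the output size is computed in one place rather than at each call site.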

This cuts function signatures from 10-14 parameters down to 4-5:

func conv2dInputBackwardFloat32(
    inputGrad, grad, kernel *tensor.RawTensor,
    dims *ConvDims,
)
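To show how a function body might read once it takes the struct, here is a minimal sketch of a naive forward convolution. It is not the existing implementation: `conv2dForwardSketch` is a hypothetical name, plain `[]float32` slices stand in for `*tensor.RawTensor`, and a contiguous NCHW layout with zero padding is assumed.

```go
package main

// ConvDims mirrors the struct proposed in this issue.
type ConvDims struct {
	N, CIn, H, W    int // Input dimensions
	COut, KH, KW    int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Convolution parameters
}

// conv2dForwardSketch is a hypothetical direct convolution over flat
// slices. Assumed layouts: input [N, CIn, H, W], kernel [COut, CIn, KH, KW],
// output [N, COut, HOut, WOut], all contiguous.
func conv2dForwardSketch(output, input, kernel []float32, dims *ConvDims) {
	d := *dims // local copy, so this function cannot mutate the caller's dims
	for n := 0; n < d.N; n++ {
		for co := 0; co < d.COut; co++ {
			for oh := 0; oh < d.HOut; oh++ {
				for ow := 0; ow < d.WOut; ow++ {
					var sum float32
					for ci := 0; ci < d.CIn; ci++ {
						for kh := 0; kh < d.KH; kh++ {
							ih := oh*d.Stride + kh - d.Padding
							if ih < 0 || ih >= d.H {
								continue // row falls in the zero padding
							}
							for kw := 0; kw < d.KW; kw++ {
								iw := ow*d.Stride + kw - d.Padding
								if iw < 0 || iw >= d.W {
									continue // column falls in the zero padding
								}
								inIdx := ((n*d.CIn+ci)*d.H+ih)*d.W + iw
								kIdx := ((co*d.CIn+ci)*d.KH+kh)*d.KW + kw
								sum += input[inIdx] * kernel[kIdx]
							}
						}
					}
					output[((n*d.COut+co)*d.HOut+oh)*d.WOut+ow] = sum
				}
			}
		}
	}
}
```

The index arithmetic is unchanged from what a positional-argument version would do; only the `d.` prefix is new, which is why the refactoring should be mechanical.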

Files to Update

  • internal/backend/cpu/conv2d.go
  • internal/backend/cpu/conv2d_backward.go
  • internal/backend/cpu/maxpool2d.go
  • internal/autodiff/ops/conv2d.go
  • internal/autodiff/ops/maxpool2d.go

Acceptance Criteria

  • ConvDims struct introduced
  • PoolDims struct introduced
  • All conv2d functions updated to use structs
  • All maxpool2d functions updated to use structs
  • No performance regression (benchmark comparison)
  • All tests pass
  • golangci-lint: 0 issues
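For the benchmark criterion, one possible starting point is a micro-benchmark that isolates call overhead for the two signature styles. This is a sketch under assumptions: the function names are hypothetical, the bodies are trivial stand-ins, and the real comparison should run the actual conv2d and maxpool2d kernels before and after the change.

```go
package main

import "testing"

// ConvDims mirrors the struct proposed in this issue.
type ConvDims struct {
	N, CIn, H, W    int // Input dimensions
	COut, KH, KW    int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Convolution parameters
}

// macsPositional counts multiply-accumulate operations using the
// 11-int positional style. Some parameters are unused on purpose:
// they are present only to mirror the current signatures.
//
//go:noinline
func macsPositional(n, cIn, h, w, cOut, kH, kW, hOut, wOut, stride, padding int) int {
	return n * cOut * hOut * wOut * cIn * kH * kW
}

// macsStruct computes the same value from a *ConvDims.
//
//go:noinline
func macsStruct(d *ConvDims) int {
	return d.N * d.COut * d.HOut * d.WOut * d.CIn * d.KH * d.KW
}

// sink keeps the compiler from discarding the benchmarked calls.
var sink int

func BenchmarkMacsPositional(b *testing.B) {
	for i := 0; i < b.N; i++ {
		sink = macsPositional(1, 3, 32, 32, 16, 3, 3, 32, 32, 1, 1)
	}
}

func BenchmarkMacsStruct(b *testing.B) {
	d := &ConvDims{N: 1, CIn: 3, H: 32, W: 32, COut: 16, KH: 3, KW: 3,
		HOut: 32, WOut: 32, Stride: 1, Padding: 1}
	for i := 0; i < b.N; i++ {
		sink = macsStruct(d)
	}
}
```

Placed in a `_test.go` file, these run with `go test -bench=. -benchmem`. In a real kernel the per-call cost is likely dwarfed by the inner loops, so the before/after numbers on the actual functions are the ones that matter for this criterion.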

Risk Assessment

Low risk: this is a pure refactoring that leaves the algorithmic logic unchanged.

Related

Labels

enhancement, refactoring, code-quality

    Labels

  • area: cpu (CPU backend, element-wise ops, BLAS)
  • area: nn (NN modules: Linear, Conv2D, RMSNorm, Embedding)
  • effort: 3 (Medium, ~1 day)
  • priority: medium (Normal priority)
  • status: confirmed (Verified, ready for work)
  • type: enhancement (Improve existing feature)
