Status: Open
Labels: area: cpu (CPU backend, element-wise ops, BLAS); area: nn (NN modules: Linear, Conv2D, RMSNorm, Embedding); effort: 3 (Medium, ~1 day); priority: medium (Normal priority); status: confirmed (Verified, ready for work); type: enhancement (Improve existing feature)
Description
Background
Following feedback from @marcelloh in #14, our conv2d and maxpool2d functions have 10-14 arguments each, which exceeds the recommended maximum of 5.
Problem
Functions like conv2dInputBackwardFloat32 have 14 parameters:
func conv2dInputBackwardFloat32(
	inputGrad, grad, kernel *tensor.RawTensor,
	n, cIn, h, w, cOut, kH, kW, hOut, wOut, stride, padding int,
)
Proposed Solution
Introduce parameter structs following the Burn framework pattern:
// ConvDims groups convolution dimension parameters.
type ConvDims struct {
	N, CIn, H, W    int // Input dimensions
	COut, KH, KW    int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Convolution parameters
}

// PoolDims groups pooling dimension parameters.
type PoolDims struct {
	N, C, H, W      int // Input dimensions
	KH, KW          int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Pooling parameters
}
This reduces function signatures from 14 to 4-5 parameters:
func conv2dInputBackwardFloat32(
	inputGrad, grad, kernel *tensor.RawTensor,
	dims *ConvDims,
)
Files to Update
- internal/backend/cpu/conv2d.go
- internal/backend/cpu/conv2d_backward.go
- internal/backend/cpu/maxpool2d.go
- internal/autodiff/ops/conv2d.go
- internal/autodiff/ops/maxpool2d.go
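As a rough sketch of how call sites in these files might build the new struct, assuming the `ConvDims` field layout proposed above, a hypothetical `Complete` method could validate the struct and derive `HOut`/`WOut` from the input and kernel fields using the usual output-size formula for a convolution without dilation, `(H + 2*Padding - KH)/Stride + 1`. The method name and the validation rules are illustrative assumptions, not existing code in the repository; plain Go values stand in for `tensor.RawTensor`.

```go
package main

import (
	"errors"
	"fmt"
)

// ConvDims mirrors the struct proposed in this issue.
type ConvDims struct {
	N, CIn, H, W    int // Input dimensions
	COut, KH, KW    int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Convolution parameters
}

// convOutSize applies the standard output-size formula for a convolution
// without dilation (integer division floors for non-negative operands).
func convOutSize(in, kernel, stride, padding int) int {
	return (in+2*padding-kernel)/stride + 1
}

// Complete is a hypothetical helper: it checks the input/kernel fields and
// fills HOut/WOut so they cannot drift out of sync with H, W, KH, KW.
func (d *ConvDims) Complete() error {
	if d.Stride <= 0 {
		return errors.New("conv2d: stride must be positive")
	}
	if d.Padding < 0 {
		return errors.New("conv2d: padding must be non-negative")
	}
	if d.KH <= 0 || d.KW <= 0 {
		return errors.New("conv2d: kernel dimensions must be positive")
	}
	if d.H+2*d.Padding < d.KH || d.W+2*d.Padding < d.KW {
		return errors.New("conv2d: kernel larger than padded input")
	}
	d.HOut = convOutSize(d.H, d.KH, d.Stride, d.Padding)
	d.WOut = convOutSize(d.W, d.KW, d.Stride, d.Padding)
	return nil
}

func main() {
	// 1x3x32x32 input, 16 output channels, 3x3 kernel, stride 2, padding 1.
	d := &ConvDims{N: 1, CIn: 3, H: 32, W: 32, COut: 16, KH: 3, KW: 3, Stride: 2, Padding: 1}
	if err := d.Complete(); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(d.HOut, d.WOut) // 16 16
}
```

Deriving the output sizes in one place means a caller can no longer pass an `HOut` that disagrees with `H`, `KH`, `Stride`, and `Padding`. Whether to add such a helper is a separate decision from the struct refactor itself.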
Acceptance Criteria
- ConvDims struct introduced
- PoolDims struct introduced
- All conv2d functions updated to use structs
- All maxpool2d functions updated to use structs
- No performance regression (benchmark comparison)
- All tests pass
- golangci-lint: 0 issues
Risk Assessment
Low risk: this is a pure refactoring that does not change any algorithmic logic.
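One way to keep the change low-risk is to move each loop body unchanged into the struct-based function and, temporarily, keep the old flat signature as a thin adapter, so existing tests exercise exactly the same code path. The sketch below shows that pattern for a max-pooling forward pass over an NCHW `[]float32` slice; `maxPool2d`, `maxPool2dFlat`, and the slice-based data layout are hypothetical stand-ins for the repository's `tensor.RawTensor` code.

```go
package main

import (
	"fmt"
	"math"
)

// PoolDims mirrors the struct proposed in this issue.
type PoolDims struct {
	N, C, H, W      int // Input dimensions
	KH, KW          int // Kernel dimensions
	HOut, WOut      int // Output dimensions
	Stride, Padding int // Pooling parameters
}

// maxPool2d computes max pooling over an NCHW slice. Padded positions are
// skipped, which is equivalent to padding with -Inf.
func maxPool2d(input []float32, d *PoolDims) []float32 {
	out := make([]float32, d.N*d.C*d.HOut*d.WOut)
	for n := 0; n < d.N; n++ {
		for c := 0; c < d.C; c++ {
			inBase := (n*d.C + c) * d.H * d.W
			outBase := (n*d.C + c) * d.HOut * d.WOut
			for oh := 0; oh < d.HOut; oh++ {
				for ow := 0; ow < d.WOut; ow++ {
					best := float32(math.Inf(-1))
					for kh := 0; kh < d.KH; kh++ {
						ih := oh*d.Stride - d.Padding + kh
						if ih < 0 || ih >= d.H {
							continue
						}
						for kw := 0; kw < d.KW; kw++ {
							iw := ow*d.Stride - d.Padding + kw
							if iw < 0 || iw >= d.W {
								continue
							}
							if v := input[inBase+ih*d.W+iw]; v > best {
								best = v
							}
						}
					}
					out[outBase+oh*d.WOut+ow] = best
				}
			}
		}
	}
	return out
}

// maxPool2dFlat is a temporary adapter with the old flat signature. It only
// packs its arguments into PoolDims, so old and new paths share one loop.
func maxPool2dFlat(input []float32, n, c, h, w, kH, kW, hOut, wOut, stride, padding int) []float32 {
	return maxPool2d(input, &PoolDims{
		N: n, C: c, H: h, W: w, KH: kH, KW: kW,
		HOut: hOut, WOut: wOut, Stride: stride, Padding: padding,
	})
}

func main() {
	input := make([]float32, 16) // 1x1x4x4 input holding 0..15
	for i := range input {
		input[i] = float32(i)
	}
	d := &PoolDims{N: 1, C: 1, H: 4, W: 4, KH: 2, KW: 2, HOut: 2, WOut: 2, Stride: 2}
	fmt.Println(maxPool2d(input, d))                                // [5 7 13 15]
	fmt.Println(maxPool2dFlat(input, 1, 1, 4, 4, 2, 2, 2, 2, 2, 0)) // [5 7 13 15]
}
```

The adapter itself still has 11 parameters, so it would need to be deleted once all callers are migrated to meet the golangci-lint criterion.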
Related
- #14 (Avoid too complex code): original complexity analysis
Labels
enhancement, refactoring, code-quality