diff --git a/vmap/structure.go b/vmap/structure.go index 4630544..71c7aa8 100644 --- a/vmap/structure.go +++ b/vmap/structure.go @@ -1,6 +1,7 @@ package vmap import ( + "bytes" "encoding/xml" "fmt" "strconv" @@ -132,36 +133,31 @@ type CreativeParameter struct { type Duration struct{ time.Duration } +var formatStrings = [...]string{"h", "m", "s", "ms"} + func (d *Duration) UnmarshalText(data []byte) error { - s := string(data) - s = strings.TrimSpace(s) - if s == "" { - *d = Duration{} - return nil - } - parts := strings.Split(s, ":") - if len(parts) != 3 { - return fmt.Errorf("invalid duration format: %s", s) + var sb bytes.Buffer + currentPart := 0 + + for i := 0; i < len(data); i++ { + b := data[i] + switch b { + case ':', '.': + if currentPart == 3 { + return fmt.Errorf("invalid duration format: %s", string(data)) + } + sb.WriteString(formatStrings[currentPart]) + currentPart++ + case '1', '2', '3', '4', '5', '6', '7', '8', '9', '0': + sb.WriteByte(b) + } } - // TODO: Figure this part out - hours, minutes, seconds := parts[0], parts[1], parts[2] - var sb strings.Builder - dur := time.Duration(0) - sb.WriteString(hours) - sb.WriteString("h") - sb.WriteString(minutes) - sb.WriteString("m") - // TODO: Handle seconds with decimal - if strings.Contains(seconds, ".") { - parts := strings.Split(seconds, ".") - sb.WriteString(parts[0]) - sb.WriteString("s") - sb.WriteString(parts[1]) - sb.WriteString("ms") - } else { - sb.WriteString(seconds) - sb.WriteString("s") + sb.WriteString(formatStrings[currentPart]) + + if currentPart < 2 { + return fmt.Errorf("invalid duration format: %s", string(data)) } + dur, err := time.ParseDuration(sb.String()) if err != nil { return fmt.Errorf("error parsing duration: %w", err) diff --git a/vmap/structure_test.go b/vmap/structure_test.go index ddde73f..63264b5 100644 --- a/vmap/structure_test.go +++ b/vmap/structure_test.go @@ -123,6 +123,12 @@ func TestUnmarshalDuration(t *testing.T) { err = d.UnmarshalText([]byte("04:01:12.345")) is.NoErr(err) is.Equal(d.Duration, 4*time.Hour+1*time.Minute+12*time.Second+345*time.Millisecond) + + err = d.UnmarshalText([]byte("01:04:01:12.345")) + is.True(err != nil) + + err = d.UnmarshalText([]byte("01:04")) + is.True(err != nil) } func TestMarshalJson(t *testing.T) {