diff --git a/duration.go b/duration.go index 7b14a69..749c4b5 100644 --- a/duration.go +++ b/duration.go @@ -18,6 +18,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "math" "regexp" "strconv" "strings" @@ -60,7 +61,7 @@ var ( "w": hoursInDay * daysInWeek * time.Hour, } - durationMatcher = regexp.MustCompile(`(((?:-\s?)?\d+)\s*([A-Za-zµ]+))`) + durationMatcher = regexp.MustCompile(`^(((?:-\s?)?\d+)(\.\d+)?\s*([A-Za-zµ]+))`) ) // IsDuration returns true if the provided string is a valid duration @@ -100,11 +101,24 @@ func ParseDuration(cand string) (time.Duration, error) { var dur time.Duration ok := false + const expectGroups = 4 for _, match := range durationMatcher.FindAllStringSubmatch(cand, -1) { + if len(match) < expectGroups { + continue + } // remove possible leading - and spaces value, negative := strings.CutPrefix(match[2], "-") + // if the duration contains a decimal separator determine a divising factor + const neutral = 1.0 + divisor := neutral + decimal, hasDecimal := strings.CutPrefix(match[3], ".") + if hasDecimal { + divisor = math.Pow10(len(decimal)) + value += decimal // consider the value as an integer: will change units later on + } + // if the string is a valid duration, parse it factor, err := strconv.Atoi(strings.TrimSpace(value)) // converts string to int if err != nil { @@ -115,7 +129,7 @@ func ParseDuration(cand string) (time.Duration, error) { factor = -factor } - unit := strings.ToLower(strings.TrimSpace(match[3])) + unit := strings.ToLower(strings.TrimSpace(match[4])) for _, variants := range timeUnits { last := len(variants) - 1 @@ -124,6 +138,9 @@ func ParseDuration(cand string) (time.Duration, error) { for i, variant := range variants { if (last == i && strings.HasPrefix(unit, variant)) || strings.EqualFold(variant, unit) { ok = true + if divisor != neutral { + multiplier = time.Duration(float64(multiplier) / divisor) // convert to duration only after having reduced the scale + } dur += (time.Duration(factor) * multiplier) } } diff --git a/duration_test.go b/duration_test.go index de6a5f5..d76fa36 100644 --- a/duration_test.go +++ b/duration_test.go @@ -15,6 +15,7 @@ package strfmt import ( + "fmt" "testing" "time" @@ -230,3 +231,67 @@ func TestDeepCopyDuration(t *testing.T) { out3 := inNil.DeepCopy() assert.Nil(t, out3) } + +func TestIssue169FractionalDuration(t *testing.T) { + for _, tt := range []struct { + Input string + Expected string + ExpectError bool + }{ + { + Input: "1.5 h", + Expected: "1h30m0s", + }, + { + Input: "1.5 d", + Expected: "36h0m0s", + }, + { + Input: "3.14159 d", + Expected: "75h23m53.376s", + }, + { + Input: "- 3.14159 d", + Expected: "-75h23m53.376s", + }, + { + Input: "3.141.59 d", + ExpectError: true, + }, + { + Input: ".314159 d", + ExpectError: true, + }, + { + Input: "314159. d", + ExpectError: true, + }, + } { + fractionalDuration := tt + + if fractionalDuration.ExpectError { + t.Run(fmt.Sprintf("invalid fractional duration %s should NOT parse", fractionalDuration.Input), func(t *testing.T) { + t.Parallel() + + require.False(t, IsDuration(fractionalDuration.Input)) + }) + + continue + } + + t.Run(fmt.Sprintf("fractional duration %s should parse", fractionalDuration.Input), func(t *testing.T) { + t.Parallel() + + require.True(t, IsDuration(fractionalDuration.Input)) + + var d Duration + require.NoError(t, d.UnmarshalText([]byte(fractionalDuration.Input))) + + require.Equal(t, fractionalDuration.Expected, d.String()) + + dd, err := ParseDuration(fractionalDuration.Input) + require.NoError(t, err) + require.Equal(t, fractionalDuration.Expected, dd.String()) + }) + } +}