Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions news/fix-sasmodels.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
**Added:**

* <news item>

**Changed:**

* Refactored code utilizing sasmodels to use the new sasview api.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please move to fixed. This is a bug fix not a change. Reserve changes for changes in behavior that a user might need to know about.


**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* <news item>

**Security:**

* <news item>
19 changes: 16 additions & 3 deletions src/diffpy/srfit/sas/sasparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from diffpy.srfit.exceptions import ParseError
from diffpy.srfit.fitbase.profileparser import ProfileParser
from diffpy.srfit.sas.sasimport import sasimport
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted the import from this file as not used but I haven't deleted sasimport yet because the sas package in sasview still exists, and we still need to import from it.



class SASParser(ProfileParser):
Expand Down Expand Up @@ -102,10 +101,15 @@ def parseFile(self, filename):
Raises IOError if the file cannot be read
Raises ParseError if the file cannot be parsed
"""
import sasdata.dataloader.loader as ld

Loader = sasimport("sas.dataloader.loader").Loader
Loader = ld.Loader
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please could you make ld more readable. maybe sas_dataloader?

loader = Loader()

# Convert Path object to string if needed
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

load seems to expect a string now because it's calling .lower() on filename. In the source code, the traceback leads to the lookup() function in sasdata.data_util.registry, which calls:

path_lower = path.lower()

if not isinstance(filename, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just data = loader.load(str(filename)). I don't think we have to wrap this in a conditional.

filename = str(filename)

try:
data = loader.load(filename)
except RuntimeError as e:
Expand All @@ -118,7 +122,16 @@ def parseFile(self, filename):
self._meta["filename"] = filename
self._meta["datainfo"] = data

self._banks.append([data.x, data.y, data.dx, data.dy])
# Handle case where loader returns a list of data objects
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove comment. Just make you code as readable as possible.

if isinstance(data, list):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loader returns a list now. From source code:

def load(self, file_path_list: Union[List[Union[str, Path]], str, Path],
             format: Optional[Union[List[str], str]] = None
             ) -> List[Union[Data1D, Data2D]]:

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we presumably don't need the conditional then. If it returns a list then just treat a list. We don't have to backwards compatible because we are only supporting recent versions of all dependencies.

# If it's a list, iterate through each data object
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove gratuitous comments

for data_obj in data:
self._banks.append(
[data_obj.x, data_obj.y, data_obj.dx, data_obj.dy]
)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this else

# If it's a single data object, use it directly
self._banks.append([data.x, data.y, data.dx, data.dy])
self.selectBank(0)
return

Expand Down
10 changes: 9 additions & 1 deletion src/diffpy/srfit/sas/sasprofile.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,15 @@ def __init__(self, datainfo):
datainfo
The DataInfo object this wraps.
"""
self._datainfo = datainfo
# Handle case where datainfo is a list (new sasdata behavior)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above. We can now assume the datatype is a list.

if isinstance(datainfo, list):
if len(datainfo) == 0:
raise ValueError("Empty datainfo list provided")
# Use the first data object if it's a list
self._datainfo = datainfo[0]
else:
# Single datainfo object (legacy behavior)
self._datainfo = datainfo
Profile.__init__(self)

self._xobs = self._datainfo.x
Expand Down
8 changes: 5 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import six

import diffpy.srfit.equation.literals as literals
from diffpy.srfit.sas.sasimport import sasimport

logger = logging.getLogger(__name__)


@lru_cache()
def has_sas():
try:
sasimport("sas.pr.invertor")
sasimport("sas.models")
import sas
import sasmodels

del sas
del sasmodels
return True
except ImportError:
return False
Expand Down
16 changes: 9 additions & 7 deletions tests/test_characteristicfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import numpy
import pytest

# Use the updated SasView model API to load models
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove comments

from sasmodels.sasview_model import _make_standard_model

import diffpy.srfit.pdf.characteristicfunctions as cf
from diffpy.srfit.sas.sasimport import sasimport

# # Global variables to be assigned in setUp
# cf = None
Expand All @@ -34,7 +36,7 @@ def testSphere(sas_available):
pytest.skip("sas package not available")
radius = 25
# Calculate sphere cf from SphereModel
SphereModel = sasimport("sas.models.SphereModel").SphereModel
SphereModel = _make_standard_model("sphere")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems a bit odd that we are importing a private function. Are we sure this is the way we are supposed to be using the API?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll keep looking into this

model = SphereModel()
model.setParam("radius", radius)
ff = cf.SASCF("sphere", model)
Expand All @@ -56,10 +58,10 @@ def testSpheroid(sas_available):
prad = 20.9
erad = 33.114
# Calculate cf from EllipsoidModel
EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel
EllipsoidModel = _make_standard_model("ellipsoid")
model = EllipsoidModel()
model.setParam("radius_a", prad)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not find any mention of the use of radius_a and radius_b in the documentation for sasview, even ones that dated back to version 4.x (the latest release is version 6.1.0). I can only assume that suitable replacements are radius_polar and radius_equatorial, which are the parameters that the ellipsoid model now uses.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EllipsoidModel must have a major and a minor axis, so these presumably refer to this? What are these axes called in EllipsoidModel?

model.setParam("radius_b", erad)
model.setParam("radius_polar", prad)
model.setParam("radius_equatorial", erad)
ff = cf.SASCF("spheroid", model)
r = numpy.arange(0, 100, 1 / numpy.pi, dtype=float)
fr1 = ff(r)
Expand All @@ -79,7 +81,7 @@ def testShell(sas_available):
radius = 19.2
thickness = 7.8
# Calculate cf from VesicleModel
VesicleModel = sasimport("sas.models.VesicleModel").VesicleModel
VesicleModel = _make_standard_model("vesicle")
model = VesicleModel()
model.setParam("radius", radius)
model.setParam("thickness", thickness)
Expand All @@ -103,7 +105,7 @@ def testCylinder(sas_available):
radius = 100
length = 30

CylinderModel = sasimport("sas.models.CylinderModel").CylinderModel
CylinderModel = _make_standard_model("cylinder")
model = CylinderModel()
model.setParam("radius", radius)
model.setParam("length", length)
Expand Down
20 changes: 12 additions & 8 deletions tests/test_sas.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import numpy
import pytest

# Use the updated SasView model API to load models
from sasmodels.sasview_model import _make_standard_model

from diffpy.srfit.sas import SASGenerator, SASParser, SASProfile
from diffpy.srfit.sas.sasimport import sasimport

# ----------------------------------------------------------------------------
# FIXME: adjust sensitivity of the pytest.approx statements when ready to test
Expand Down Expand Up @@ -113,7 +115,7 @@ def testParser(sas_available, datafile):
def test_generator(sas_available):
if not sas_available:
pytest.skip("sas package not available")
SphereModel = sasimport("sas.models.SphereModel").SphereModel
SphereModel = _make_standard_model("sphere")
model = SphereModel()
gen = SASGenerator("sphere", model)
for pname in model.params:
Expand All @@ -140,25 +142,27 @@ def test_generator(sas_available):
def testGenerator2(sas_available, datafile):
if not sas_available:
pytest.skip("sas package not available")
EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel
EllipsoidModel = _make_standard_model("ellipsoid")
model = EllipsoidModel()
gen = SASGenerator("ellipsoid", model)

# Load the data using SAS tools
Loader = sasimport("sas.dataloader.loader").Loader
import sasdata.dataloader.loader as ld

Loader = ld.Loader
loader = Loader()
data = datafile("sas_ellipsoid_testdata.txt")
datainfo = loader.load(data)
datainfo = loader.load(str(data))
profile = SASProfile(datainfo)

gen.setProfile(profile)
gen.scale.value = 1.0
gen.radius_a.value = 20
gen.radius_b.value = 400
gen.radius_polar.value = 20
gen.radius_equatorial.value = 400
gen.background.value = 0.01

y = gen(profile.xobs)
diff = profile.yobs - y
res = numpy.dot(diff, diff)
assert 0 == pytest.approx(res)
assert 0 == pytest.approx(res, abs=1e-3)
return
Loading